Spaces:
Runtime error
Runtime error
Commit
·
0fc4c70
0
Parent(s):
Duplicate from bluefoxcreation/SwapMukham
Browse filesCo-authored-by: BlueFox <[email protected]>
- .gitattributes +35 -0
- .gitignore +2 -0
- README.md +14 -0
- app.py +924 -0
- assets/images/logo.png +0 -0
- assets/pretrained_models/79999_iter.pth +3 -0
- assets/pretrained_models/GFPGANv1.4.pth +3 -0
- assets/pretrained_models/RealESRGAN_x2.pth +3 -0
- assets/pretrained_models/RealESRGAN_x4.pth +3 -0
- assets/pretrained_models/RealESRGAN_x8.pth +3 -0
- assets/pretrained_models/codeformer.onnx +3 -0
- assets/pretrained_models/inswapper_128.onnx +3 -0
- assets/pretrained_models/open-nsfw.onnx +3 -0
- assets/pretrained_models/readme.md +4 -0
- face_analyser.py +194 -0
- face_enhancer.py +72 -0
- face_parsing/__init__.py +3 -0
- face_parsing/model.py +283 -0
- face_parsing/parse_mask.py +107 -0
- face_parsing/resnet.py +109 -0
- face_parsing/swap.py +133 -0
- face_swapper.py +150 -0
- gfpgan/weights/detection_Resnet50_Final.pth +3 -0
- gfpgan/weights/parsing_parsenet.pth +3 -0
- nsfw_checker/LICENSE.md +11 -0
- nsfw_checker/__init__.py +1 -0
- nsfw_checker/opennsfw.py +37 -0
- nsfw_detector.py +65 -0
- requirements.txt +12 -0
- upscaler/RealESRGAN/__init__.py +1 -0
- upscaler/RealESRGAN/arch_utils.py +197 -0
- upscaler/RealESRGAN/model.py +90 -0
- upscaler/RealESRGAN/rrdbnet_arch.py +121 -0
- upscaler/RealESRGAN/utils.py +133 -0
- upscaler/__init__.py +0 -0
- upscaler/codeformer.py +37 -0
- utils.py +303 -0
.gitattributes
ADDED
@@ -0,0 +1,35 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
*.7z filter=lfs diff=lfs merge=lfs -text
|
2 |
+
*.arrow filter=lfs diff=lfs merge=lfs -text
|
3 |
+
*.bin filter=lfs diff=lfs merge=lfs -text
|
4 |
+
*.bz2 filter=lfs diff=lfs merge=lfs -text
|
5 |
+
*.ckpt filter=lfs diff=lfs merge=lfs -text
|
6 |
+
*.ftz filter=lfs diff=lfs merge=lfs -text
|
7 |
+
*.gz filter=lfs diff=lfs merge=lfs -text
|
8 |
+
*.h5 filter=lfs diff=lfs merge=lfs -text
|
9 |
+
*.joblib filter=lfs diff=lfs merge=lfs -text
|
10 |
+
*.lfs.* filter=lfs diff=lfs merge=lfs -text
|
11 |
+
*.mlmodel filter=lfs diff=lfs merge=lfs -text
|
12 |
+
*.model filter=lfs diff=lfs merge=lfs -text
|
13 |
+
*.msgpack filter=lfs diff=lfs merge=lfs -text
|
14 |
+
*.npy filter=lfs diff=lfs merge=lfs -text
|
15 |
+
*.npz filter=lfs diff=lfs merge=lfs -text
|
16 |
+
*.onnx filter=lfs diff=lfs merge=lfs -text
|
17 |
+
*.ot filter=lfs diff=lfs merge=lfs -text
|
18 |
+
*.parquet filter=lfs diff=lfs merge=lfs -text
|
19 |
+
*.pb filter=lfs diff=lfs merge=lfs -text
|
20 |
+
*.pickle filter=lfs diff=lfs merge=lfs -text
|
21 |
+
*.pkl filter=lfs diff=lfs merge=lfs -text
|
22 |
+
*.pt filter=lfs diff=lfs merge=lfs -text
|
23 |
+
*.pth filter=lfs diff=lfs merge=lfs -text
|
24 |
+
*.rar filter=lfs diff=lfs merge=lfs -text
|
25 |
+
*.safetensors filter=lfs diff=lfs merge=lfs -text
|
26 |
+
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
27 |
+
*.tar.* filter=lfs diff=lfs merge=lfs -text
|
28 |
+
*.tar filter=lfs diff=lfs merge=lfs -text
|
29 |
+
*.tflite filter=lfs diff=lfs merge=lfs -text
|
30 |
+
*.tgz filter=lfs diff=lfs merge=lfs -text
|
31 |
+
*.wasm filter=lfs diff=lfs merge=lfs -text
|
32 |
+
*.xz filter=lfs diff=lfs merge=lfs -text
|
33 |
+
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
+
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
+
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
.gitignore
ADDED
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
1 |
+
|
2 |
+
*.pyc
|
README.md
ADDED
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
---
|
2 |
+
title: Swap Mukham
|
3 |
+
emoji: 💻
|
4 |
+
colorFrom: purple
|
5 |
+
colorTo: gray
|
6 |
+
sdk: gradio
|
7 |
+
sdk_version: 3.35.2
|
8 |
+
app_file: app.py
|
9 |
+
pinned: false
|
10 |
+
license: unknown
|
11 |
+
duplicated_from: bluefoxcreation/SwapMukham
|
12 |
+
---
|
13 |
+
|
14 |
+
Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
|
app.py
ADDED
@@ -0,0 +1,924 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import cv2
|
3 |
+
import glob
|
4 |
+
import time
|
5 |
+
import torch
|
6 |
+
import shutil
|
7 |
+
import argparse
|
8 |
+
import platform
|
9 |
+
import datetime
|
10 |
+
import subprocess
|
11 |
+
import insightface
|
12 |
+
import onnxruntime
|
13 |
+
import numpy as np
|
14 |
+
import gradio as gr
|
15 |
+
import threading
|
16 |
+
import queue
|
17 |
+
from tqdm import tqdm
|
18 |
+
import concurrent.futures
|
19 |
+
from moviepy.editor import VideoFileClip
|
20 |
+
|
21 |
+
from nsfw_checker import NSFWChecker
|
22 |
+
from face_swapper import Inswapper, paste_to_whole
|
23 |
+
from face_analyser import detect_conditions, get_analysed_data, swap_options_list
|
24 |
+
from face_parsing import init_parsing_model, get_parsed_mask, mask_regions, mask_regions_to_list
|
25 |
+
from face_enhancer import get_available_enhancer_names, load_face_enhancer_model, cv2_interpolations
|
26 |
+
from utils import trim_video, StreamerThread, ProcessBar, open_directory, split_list_by_lengths, merge_img_sequence_from_ref, create_image_grid
|
27 |
+
|
28 |
+
## ------------------------------ USER ARGS ------------------------------
|
29 |
+
|
30 |
+
parser = argparse.ArgumentParser(description="Swap-Mukham Face Swapper")
|
31 |
+
parser.add_argument("--out_dir", help="Default Output directory", default=os.getcwd())
|
32 |
+
parser.add_argument("--batch_size", help="Gpu batch size", default=32)
|
33 |
+
parser.add_argument("--cuda", action="store_true", help="Enable cuda", default=False)
|
34 |
+
parser.add_argument(
|
35 |
+
"--colab", action="store_true", help="Enable colab mode", default=False
|
36 |
+
)
|
37 |
+
user_args = parser.parse_args()
|
38 |
+
|
39 |
+
## ------------------------------ DEFAULTS ------------------------------
|
40 |
+
|
41 |
+
USE_COLAB = user_args.colab
|
42 |
+
USE_CUDA = user_args.cuda
|
43 |
+
DEF_OUTPUT_PATH = user_args.out_dir
|
44 |
+
BATCH_SIZE = int(user_args.batch_size)
|
45 |
+
WORKSPACE = None
|
46 |
+
OUTPUT_FILE = None
|
47 |
+
CURRENT_FRAME = None
|
48 |
+
STREAMER = None
|
49 |
+
DETECT_CONDITION = "best detection"
|
50 |
+
DETECT_SIZE = 640
|
51 |
+
DETECT_THRESH = 0.6
|
52 |
+
NUM_OF_SRC_SPECIFIC = 10
|
53 |
+
MASK_INCLUDE = [
|
54 |
+
"Skin",
|
55 |
+
"R-Eyebrow",
|
56 |
+
"L-Eyebrow",
|
57 |
+
"L-Eye",
|
58 |
+
"R-Eye",
|
59 |
+
"Nose",
|
60 |
+
"Mouth",
|
61 |
+
"L-Lip",
|
62 |
+
"U-Lip"
|
63 |
+
]
|
64 |
+
MASK_SOFT_KERNEL = 17
|
65 |
+
MASK_SOFT_ITERATIONS = 10
|
66 |
+
MASK_BLUR_AMOUNT = 0.1
|
67 |
+
MASK_ERODE_AMOUNT = 0.15
|
68 |
+
|
69 |
+
FACE_SWAPPER = None
|
70 |
+
FACE_ANALYSER = None
|
71 |
+
FACE_ENHANCER = None
|
72 |
+
FACE_PARSER = None
|
73 |
+
NSFW_DETECTOR = None
|
74 |
+
FACE_ENHANCER_LIST = ["NONE"]
|
75 |
+
FACE_ENHANCER_LIST.extend(get_available_enhancer_names())
|
76 |
+
FACE_ENHANCER_LIST.extend(cv2_interpolations)
|
77 |
+
|
78 |
+
## ------------------------------ SET EXECUTION PROVIDER ------------------------------
|
79 |
+
# Note: Non CUDA users may change settings here
|
80 |
+
|
81 |
+
PROVIDER = ["CPUExecutionProvider"]
|
82 |
+
|
83 |
+
if USE_CUDA:
|
84 |
+
available_providers = onnxruntime.get_available_providers()
|
85 |
+
if "CUDAExecutionProvider" in available_providers:
|
86 |
+
print("\n********** Running on CUDA **********\n")
|
87 |
+
PROVIDER = ["CUDAExecutionProvider", "CPUExecutionProvider"]
|
88 |
+
else:
|
89 |
+
USE_CUDA = False
|
90 |
+
print("\n********** CUDA unavailable running on CPU **********\n")
|
91 |
+
else:
|
92 |
+
USE_CUDA = False
|
93 |
+
print("\n********** Running on CPU **********\n")
|
94 |
+
|
95 |
+
device = "cuda" if USE_CUDA else "cpu"
|
96 |
+
EMPTY_CACHE = lambda: torch.cuda.empty_cache() if device == "cuda" else None
|
97 |
+
|
98 |
+
## ------------------------------ LOAD MODELS ------------------------------
|
99 |
+
|
100 |
+
def load_face_analyser_model(name="buffalo_l"):
|
101 |
+
global FACE_ANALYSER
|
102 |
+
if FACE_ANALYSER is None:
|
103 |
+
FACE_ANALYSER = insightface.app.FaceAnalysis(name=name, providers=PROVIDER)
|
104 |
+
FACE_ANALYSER.prepare(
|
105 |
+
ctx_id=0, det_size=(DETECT_SIZE, DETECT_SIZE), det_thresh=DETECT_THRESH
|
106 |
+
)
|
107 |
+
|
108 |
+
|
109 |
+
def load_face_swapper_model(path="./assets/pretrained_models/inswapper_128.onnx"):
|
110 |
+
global FACE_SWAPPER
|
111 |
+
if FACE_SWAPPER is None:
|
112 |
+
batch = int(BATCH_SIZE) if device == "cuda" else 1
|
113 |
+
FACE_SWAPPER = Inswapper(model_file=path, batch_size=batch, providers=PROVIDER)
|
114 |
+
|
115 |
+
|
116 |
+
def load_face_parser_model(path="./assets/pretrained_models/79999_iter.pth"):
|
117 |
+
global FACE_PARSER
|
118 |
+
if FACE_PARSER is None:
|
119 |
+
FACE_PARSER = init_parsing_model(path, device=device)
|
120 |
+
|
121 |
+
def load_nsfw_detector_model(path="./assets/pretrained_models/open-nsfw.onnx"):
|
122 |
+
global NSFW_DETECTOR
|
123 |
+
if NSFW_DETECTOR is None:
|
124 |
+
NSFW_DETECTOR = NSFWChecker(model_path=path, providers=PROVIDER)
|
125 |
+
|
126 |
+
|
127 |
+
load_face_analyser_model()
|
128 |
+
load_face_swapper_model()
|
129 |
+
|
130 |
+
## ------------------------------ MAIN PROCESS ------------------------------
|
131 |
+
|
132 |
+
|
133 |
+
def process(
|
134 |
+
input_type,
|
135 |
+
image_path,
|
136 |
+
video_path,
|
137 |
+
directory_path,
|
138 |
+
source_path,
|
139 |
+
output_path,
|
140 |
+
output_name,
|
141 |
+
keep_output_sequence,
|
142 |
+
condition,
|
143 |
+
age,
|
144 |
+
distance,
|
145 |
+
face_enhancer_name,
|
146 |
+
enable_face_parser,
|
147 |
+
mask_includes,
|
148 |
+
mask_soft_kernel,
|
149 |
+
mask_soft_iterations,
|
150 |
+
blur_amount,
|
151 |
+
erode_amount,
|
152 |
+
face_scale,
|
153 |
+
enable_laplacian_blend,
|
154 |
+
crop_top,
|
155 |
+
crop_bott,
|
156 |
+
crop_left,
|
157 |
+
crop_right,
|
158 |
+
*specifics,
|
159 |
+
):
|
160 |
+
global WORKSPACE
|
161 |
+
global OUTPUT_FILE
|
162 |
+
global PREVIEW
|
163 |
+
WORKSPACE, OUTPUT_FILE, PREVIEW = None, None, None
|
164 |
+
|
165 |
+
## ------------------------------ GUI UPDATE FUNC ------------------------------
|
166 |
+
|
167 |
+
def ui_before():
|
168 |
+
return (
|
169 |
+
gr.update(visible=True, value=PREVIEW),
|
170 |
+
gr.update(interactive=False),
|
171 |
+
gr.update(interactive=False),
|
172 |
+
gr.update(visible=False),
|
173 |
+
)
|
174 |
+
|
175 |
+
def ui_after():
|
176 |
+
return (
|
177 |
+
gr.update(visible=True, value=PREVIEW),
|
178 |
+
gr.update(interactive=True),
|
179 |
+
gr.update(interactive=True),
|
180 |
+
gr.update(visible=False),
|
181 |
+
)
|
182 |
+
|
183 |
+
def ui_after_vid():
|
184 |
+
return (
|
185 |
+
gr.update(visible=False),
|
186 |
+
gr.update(interactive=True),
|
187 |
+
gr.update(interactive=True),
|
188 |
+
gr.update(value=OUTPUT_FILE, visible=True),
|
189 |
+
)
|
190 |
+
|
191 |
+
start_time = time.time()
|
192 |
+
total_exec_time = lambda start_time: divmod(time.time() - start_time, 60)
|
193 |
+
get_finsh_text = lambda start_time: f"✔️ Completed in {int(total_exec_time(start_time)[0])} min {int(total_exec_time(start_time)[1])} sec."
|
194 |
+
|
195 |
+
## ------------------------------ PREPARE INPUTS & LOAD MODELS ------------------------------
|
196 |
+
|
197 |
+
yield "### \n ⌛ Loading NSFW detector model...", *ui_before()
|
198 |
+
load_nsfw_detector_model()
|
199 |
+
|
200 |
+
yield "### \n ⌛ Loading face analyser model...", *ui_before()
|
201 |
+
load_face_analyser_model()
|
202 |
+
|
203 |
+
yield "### \n ⌛ Loading face swapper model...", *ui_before()
|
204 |
+
load_face_swapper_model()
|
205 |
+
|
206 |
+
if face_enhancer_name != "NONE":
|
207 |
+
if face_enhancer_name not in cv2_interpolations:
|
208 |
+
yield f"### \n ⌛ Loading {face_enhancer_name} model...", *ui_before()
|
209 |
+
FACE_ENHANCER = load_face_enhancer_model(name=face_enhancer_name, device=device)
|
210 |
+
else:
|
211 |
+
FACE_ENHANCER = None
|
212 |
+
|
213 |
+
if enable_face_parser:
|
214 |
+
yield "### \n ⌛ Loading face parsing model...", *ui_before()
|
215 |
+
load_face_parser_model()
|
216 |
+
|
217 |
+
includes = mask_regions_to_list(mask_includes)
|
218 |
+
specifics = list(specifics)
|
219 |
+
half = len(specifics) // 2
|
220 |
+
sources = specifics[:half]
|
221 |
+
specifics = specifics[half:]
|
222 |
+
if crop_top > crop_bott:
|
223 |
+
crop_top, crop_bott = crop_bott, crop_top
|
224 |
+
if crop_left > crop_right:
|
225 |
+
crop_left, crop_right = crop_right, crop_left
|
226 |
+
crop_mask = (crop_top, 511-crop_bott, crop_left, 511-crop_right)
|
227 |
+
|
228 |
+
def swap_process(image_sequence):
|
229 |
+
## ------------------------------ CONTENT CHECK ------------------------------
|
230 |
+
|
231 |
+
yield "### \n ⌛ Checking contents...", *ui_before()
|
232 |
+
nsfw = NSFW_DETECTOR.is_nsfw(image_sequence)
|
233 |
+
if nsfw:
|
234 |
+
message = "NSFW Content detected !!!"
|
235 |
+
yield f"### \n 🔞 {message}", *ui_before()
|
236 |
+
assert not nsfw, message
|
237 |
+
return False
|
238 |
+
EMPTY_CACHE()
|
239 |
+
|
240 |
+
## ------------------------------ ANALYSE FACE ------------------------------
|
241 |
+
|
242 |
+
yield "### \n ⌛ Analysing face data...", *ui_before()
|
243 |
+
if condition != "Specific Face":
|
244 |
+
source_data = source_path, age
|
245 |
+
else:
|
246 |
+
source_data = ((sources, specifics), distance)
|
247 |
+
analysed_targets, analysed_sources, whole_frame_list, num_faces_per_frame = get_analysed_data(
|
248 |
+
FACE_ANALYSER,
|
249 |
+
image_sequence,
|
250 |
+
source_data,
|
251 |
+
swap_condition=condition,
|
252 |
+
detect_condition=DETECT_CONDITION,
|
253 |
+
scale=face_scale
|
254 |
+
)
|
255 |
+
|
256 |
+
## ------------------------------ SWAP FUNC ------------------------------
|
257 |
+
|
258 |
+
yield "### \n ⌛ Generating faces...", *ui_before()
|
259 |
+
preds = []
|
260 |
+
matrs = []
|
261 |
+
count = 0
|
262 |
+
global PREVIEW
|
263 |
+
for batch_pred, batch_matr in FACE_SWAPPER.batch_forward(whole_frame_list, analysed_targets, analysed_sources):
|
264 |
+
preds.extend(batch_pred)
|
265 |
+
matrs.extend(batch_matr)
|
266 |
+
EMPTY_CACHE()
|
267 |
+
count += 1
|
268 |
+
|
269 |
+
if USE_CUDA:
|
270 |
+
image_grid = create_image_grid(batch_pred, size=128)
|
271 |
+
PREVIEW = image_grid[:, :, ::-1]
|
272 |
+
yield f"### \n ⌛ Generating face Batch {count}", *ui_before()
|
273 |
+
|
274 |
+
## ------------------------------ FACE ENHANCEMENT ------------------------------
|
275 |
+
|
276 |
+
generated_len = len(preds)
|
277 |
+
if face_enhancer_name != "NONE":
|
278 |
+
yield f"### \n ⌛ Upscaling faces with {face_enhancer_name}...", *ui_before()
|
279 |
+
for idx, pred in tqdm(enumerate(preds), total=generated_len, desc=f"Upscaling with {face_enhancer_name}"):
|
280 |
+
enhancer_model, enhancer_model_runner = FACE_ENHANCER
|
281 |
+
pred = enhancer_model_runner(pred, enhancer_model)
|
282 |
+
preds[idx] = cv2.resize(pred, (512,512))
|
283 |
+
EMPTY_CACHE()
|
284 |
+
|
285 |
+
## ------------------------------ FACE PARSING ------------------------------
|
286 |
+
|
287 |
+
if enable_face_parser:
|
288 |
+
yield "### \n ⌛ Face-parsing mask...", *ui_before()
|
289 |
+
masks = []
|
290 |
+
count = 0
|
291 |
+
for batch_mask in get_parsed_mask(FACE_PARSER, preds, classes=includes, device=device, batch_size=BATCH_SIZE, softness=int(mask_soft_iterations)):
|
292 |
+
masks.append(batch_mask)
|
293 |
+
EMPTY_CACHE()
|
294 |
+
count += 1
|
295 |
+
|
296 |
+
if len(batch_mask) > 1:
|
297 |
+
image_grid = create_image_grid(batch_mask, size=128)
|
298 |
+
PREVIEW = image_grid[:, :, ::-1]
|
299 |
+
yield f"### \n ⌛ Face parsing Batch {count}", *ui_before()
|
300 |
+
masks = np.concatenate(masks, axis=0) if len(masks) >= 1 else masks
|
301 |
+
else:
|
302 |
+
masks = [None] * generated_len
|
303 |
+
|
304 |
+
## ------------------------------ SPLIT LIST ------------------------------
|
305 |
+
|
306 |
+
split_preds = split_list_by_lengths(preds, num_faces_per_frame)
|
307 |
+
del preds
|
308 |
+
split_matrs = split_list_by_lengths(matrs, num_faces_per_frame)
|
309 |
+
del matrs
|
310 |
+
split_masks = split_list_by_lengths(masks, num_faces_per_frame)
|
311 |
+
del masks
|
312 |
+
|
313 |
+
## ------------------------------ PASTE-BACK ------------------------------
|
314 |
+
|
315 |
+
yield "### \n ⌛ Pasting back...", *ui_before()
|
316 |
+
def post_process(frame_idx, frame_img, split_preds, split_matrs, split_masks, enable_laplacian_blend, crop_mask, blur_amount, erode_amount):
|
317 |
+
whole_img_path = frame_img
|
318 |
+
whole_img = cv2.imread(whole_img_path)
|
319 |
+
blend_method = 'laplacian' if enable_laplacian_blend else 'linear'
|
320 |
+
for p, m, mask in zip(split_preds[frame_idx], split_matrs[frame_idx], split_masks[frame_idx]):
|
321 |
+
p = cv2.resize(p, (512,512))
|
322 |
+
mask = cv2.resize(mask, (512,512)) if mask is not None else None
|
323 |
+
m /= 0.25
|
324 |
+
whole_img = paste_to_whole(p, whole_img, m, mask=mask, crop_mask=crop_mask, blend_method=blend_method, blur_amount=blur_amount, erode_amount=erode_amount)
|
325 |
+
cv2.imwrite(whole_img_path, whole_img)
|
326 |
+
|
327 |
+
def concurrent_post_process(image_sequence, *args):
|
328 |
+
with concurrent.futures.ThreadPoolExecutor() as executor:
|
329 |
+
futures = []
|
330 |
+
for idx, frame_img in enumerate(image_sequence):
|
331 |
+
future = executor.submit(post_process, idx, frame_img, *args)
|
332 |
+
futures.append(future)
|
333 |
+
|
334 |
+
for future in tqdm(concurrent.futures.as_completed(futures), total=len(futures), desc="Pasting back"):
|
335 |
+
result = future.result()
|
336 |
+
|
337 |
+
concurrent_post_process(
|
338 |
+
image_sequence,
|
339 |
+
split_preds,
|
340 |
+
split_matrs,
|
341 |
+
split_masks,
|
342 |
+
enable_laplacian_blend,
|
343 |
+
crop_mask,
|
344 |
+
blur_amount,
|
345 |
+
erode_amount
|
346 |
+
)
|
347 |
+
|
348 |
+
|
349 |
+
## ------------------------------ IMAGE ------------------------------
|
350 |
+
|
351 |
+
if input_type == "Image":
|
352 |
+
target = cv2.imread(image_path)
|
353 |
+
output_file = os.path.join(output_path, output_name + ".png")
|
354 |
+
cv2.imwrite(output_file, target)
|
355 |
+
|
356 |
+
for info_update in swap_process([output_file]):
|
357 |
+
yield info_update
|
358 |
+
|
359 |
+
OUTPUT_FILE = output_file
|
360 |
+
WORKSPACE = output_path
|
361 |
+
PREVIEW = cv2.imread(output_file)[:, :, ::-1]
|
362 |
+
|
363 |
+
yield get_finsh_text(start_time), *ui_after()
|
364 |
+
|
365 |
+
## ------------------------------ VIDEO ------------------------------
|
366 |
+
|
367 |
+
elif input_type == "Video":
|
368 |
+
temp_path = os.path.join(output_path, output_name, "sequence")
|
369 |
+
os.makedirs(temp_path, exist_ok=True)
|
370 |
+
|
371 |
+
yield "### \n ⌛ Extracting video frames...", *ui_before()
|
372 |
+
image_sequence = []
|
373 |
+
cap = cv2.VideoCapture(video_path)
|
374 |
+
curr_idx = 0
|
375 |
+
while True:
|
376 |
+
ret, frame = cap.read()
|
377 |
+
if not ret:break
|
378 |
+
frame_path = os.path.join(temp_path, f"frame_{curr_idx}.jpg")
|
379 |
+
cv2.imwrite(frame_path, frame)
|
380 |
+
image_sequence.append(frame_path)
|
381 |
+
curr_idx += 1
|
382 |
+
cap.release()
|
383 |
+
cv2.destroyAllWindows()
|
384 |
+
|
385 |
+
for info_update in swap_process(image_sequence):
|
386 |
+
yield info_update
|
387 |
+
|
388 |
+
yield "### \n ⌛ Merging sequence...", *ui_before()
|
389 |
+
output_video_path = os.path.join(output_path, output_name + ".mp4")
|
390 |
+
merge_img_sequence_from_ref(video_path, image_sequence, output_video_path)
|
391 |
+
|
392 |
+
if os.path.exists(temp_path) and not keep_output_sequence:
|
393 |
+
yield "### \n ⌛ Removing temporary files...", *ui_before()
|
394 |
+
shutil.rmtree(temp_path)
|
395 |
+
|
396 |
+
WORKSPACE = output_path
|
397 |
+
OUTPUT_FILE = output_video_path
|
398 |
+
|
399 |
+
yield get_finsh_text(start_time), *ui_after_vid()
|
400 |
+
|
401 |
+
## ------------------------------ DIRECTORY ------------------------------
|
402 |
+
|
403 |
+
elif input_type == "Directory":
|
404 |
+
extensions = ["jpg", "jpeg", "png", "bmp", "tiff", "ico", "webp"]
|
405 |
+
temp_path = os.path.join(output_path, output_name)
|
406 |
+
if os.path.exists(temp_path):
|
407 |
+
shutil.rmtree(temp_path)
|
408 |
+
os.mkdir(temp_path)
|
409 |
+
|
410 |
+
file_paths =[]
|
411 |
+
for file_path in glob.glob(os.path.join(directory_path, "*")):
|
412 |
+
if any(file_path.lower().endswith(ext) for ext in extensions):
|
413 |
+
img = cv2.imread(file_path)
|
414 |
+
new_file_path = os.path.join(temp_path, os.path.basename(file_path))
|
415 |
+
cv2.imwrite(new_file_path, img)
|
416 |
+
file_paths.append(new_file_path)
|
417 |
+
|
418 |
+
for info_update in swap_process(file_paths):
|
419 |
+
yield info_update
|
420 |
+
|
421 |
+
PREVIEW = cv2.imread(file_paths[-1])[:, :, ::-1]
|
422 |
+
WORKSPACE = temp_path
|
423 |
+
OUTPUT_FILE = file_paths[-1]
|
424 |
+
|
425 |
+
yield get_finsh_text(start_time), *ui_after()
|
426 |
+
|
427 |
+
## ------------------------------ STREAM ------------------------------
|
428 |
+
|
429 |
+
elif input_type == "Stream":
|
430 |
+
pass
|
431 |
+
|
432 |
+
|
433 |
+
## ------------------------------ GRADIO FUNC ------------------------------
|
434 |
+
|
435 |
+
|
436 |
+
def update_radio(value):
|
437 |
+
if value == "Image":
|
438 |
+
return (
|
439 |
+
gr.update(visible=True),
|
440 |
+
gr.update(visible=False),
|
441 |
+
gr.update(visible=False),
|
442 |
+
)
|
443 |
+
elif value == "Video":
|
444 |
+
return (
|
445 |
+
gr.update(visible=False),
|
446 |
+
gr.update(visible=True),
|
447 |
+
gr.update(visible=False),
|
448 |
+
)
|
449 |
+
elif value == "Directory":
|
450 |
+
return (
|
451 |
+
gr.update(visible=False),
|
452 |
+
gr.update(visible=False),
|
453 |
+
gr.update(visible=True),
|
454 |
+
)
|
455 |
+
elif value == "Stream":
|
456 |
+
return (
|
457 |
+
gr.update(visible=False),
|
458 |
+
gr.update(visible=False),
|
459 |
+
gr.update(visible=True),
|
460 |
+
)
|
461 |
+
|
462 |
+
|
463 |
+
def swap_option_changed(value):
|
464 |
+
if value.startswith("Age"):
|
465 |
+
return (
|
466 |
+
gr.update(visible=True),
|
467 |
+
gr.update(visible=False),
|
468 |
+
gr.update(visible=True),
|
469 |
+
)
|
470 |
+
elif value == "Specific Face":
|
471 |
+
return (
|
472 |
+
gr.update(visible=False),
|
473 |
+
gr.update(visible=True),
|
474 |
+
gr.update(visible=False),
|
475 |
+
)
|
476 |
+
return gr.update(visible=False), gr.update(visible=False), gr.update(visible=True)
|
477 |
+
|
478 |
+
|
479 |
+
def video_changed(video_path):
|
480 |
+
sliders_update = gr.Slider.update
|
481 |
+
button_update = gr.Button.update
|
482 |
+
number_update = gr.Number.update
|
483 |
+
|
484 |
+
if video_path is None:
|
485 |
+
return (
|
486 |
+
sliders_update(minimum=0, maximum=0, value=0),
|
487 |
+
sliders_update(minimum=1, maximum=1, value=1),
|
488 |
+
number_update(value=1),
|
489 |
+
)
|
490 |
+
try:
|
491 |
+
clip = VideoFileClip(video_path)
|
492 |
+
fps = clip.fps
|
493 |
+
total_frames = clip.reader.nframes
|
494 |
+
clip.close()
|
495 |
+
return (
|
496 |
+
sliders_update(minimum=0, maximum=total_frames, value=0, interactive=True),
|
497 |
+
sliders_update(
|
498 |
+
minimum=0, maximum=total_frames, value=total_frames, interactive=True
|
499 |
+
),
|
500 |
+
number_update(value=fps),
|
501 |
+
)
|
502 |
+
except:
|
503 |
+
return (
|
504 |
+
sliders_update(value=0),
|
505 |
+
sliders_update(value=0),
|
506 |
+
number_update(value=1),
|
507 |
+
)
|
508 |
+
|
509 |
+
|
510 |
+
def analyse_settings_changed(detect_condition, detection_size, detection_threshold):
|
511 |
+
yield "### \n ⌛ Applying new values..."
|
512 |
+
global FACE_ANALYSER
|
513 |
+
global DETECT_CONDITION
|
514 |
+
DETECT_CONDITION = detect_condition
|
515 |
+
FACE_ANALYSER = insightface.app.FaceAnalysis(name="buffalo_l", providers=PROVIDER)
|
516 |
+
FACE_ANALYSER.prepare(
|
517 |
+
ctx_id=0,
|
518 |
+
det_size=(int(detection_size), int(detection_size)),
|
519 |
+
det_thresh=float(detection_threshold),
|
520 |
+
)
|
521 |
+
yield f"### \n ✔️ Applied detect condition:{detect_condition}, detection size: {detection_size}, detection threshold: {detection_threshold}"
|
522 |
+
|
523 |
+
|
524 |
+
def stop_running():
|
525 |
+
global STREAMER
|
526 |
+
if hasattr(STREAMER, "stop"):
|
527 |
+
STREAMER.stop()
|
528 |
+
STREAMER = None
|
529 |
+
return "Cancelled"
|
530 |
+
|
531 |
+
|
532 |
+
def slider_changed(show_frame, video_path, frame_index):
|
533 |
+
if not show_frame:
|
534 |
+
return None, None
|
535 |
+
if video_path is None:
|
536 |
+
return None, None
|
537 |
+
clip = VideoFileClip(video_path)
|
538 |
+
frame = clip.get_frame(frame_index / clip.fps)
|
539 |
+
frame_array = np.array(frame)
|
540 |
+
clip.close()
|
541 |
+
return gr.Image.update(value=frame_array, visible=True), gr.Video.update(
|
542 |
+
visible=False
|
543 |
+
)
|
544 |
+
|
545 |
+
|
546 |
+
def trim_and_reload(video_path, output_path, output_name, start_frame, stop_frame):
|
547 |
+
yield video_path, f"### \n ⌛ Trimming video frame {start_frame} to {stop_frame}..."
|
548 |
+
try:
|
549 |
+
output_path = os.path.join(output_path, output_name)
|
550 |
+
trimmed_video = trim_video(video_path, output_path, start_frame, stop_frame)
|
551 |
+
yield trimmed_video, "### \n ✔️ Video trimmed and reloaded."
|
552 |
+
except Exception as e:
|
553 |
+
print(e)
|
554 |
+
yield video_path, "### \n ❌ Video trimming failed. See console for more info."
|
555 |
+
|
556 |
+
|
557 |
+
## ------------------------------ GRADIO GUI ------------------------------
|
558 |
+
|
559 |
+
css = """
|
560 |
+
footer{display:none !important}
|
561 |
+
"""
|
562 |
+
|
563 |
+
with gr.Blocks(css=css) as interface:
|
564 |
+
gr.Markdown("# 🗿 Swap Mukham")
|
565 |
+
gr.Markdown("### Face swap app based on insightface inswapper.")
|
566 |
+
with gr.Row():
|
567 |
+
with gr.Row():
|
568 |
+
with gr.Column(scale=0.4):
|
569 |
+
with gr.Tab("📄 Swap Condition"):
|
570 |
+
swap_option = gr.Dropdown(
|
571 |
+
swap_options_list,
|
572 |
+
info="Choose which face or faces in the target image to swap.",
|
573 |
+
multiselect=False,
|
574 |
+
show_label=False,
|
575 |
+
value=swap_options_list[0],
|
576 |
+
interactive=True,
|
577 |
+
)
|
578 |
+
age = gr.Number(
|
579 |
+
value=25, label="Value", interactive=True, visible=False
|
580 |
+
)
|
581 |
+
|
582 |
+
with gr.Tab("🎚️ Detection Settings"):
|
583 |
+
detect_condition_dropdown = gr.Dropdown(
|
584 |
+
detect_conditions,
|
585 |
+
label="Condition",
|
586 |
+
value=DETECT_CONDITION,
|
587 |
+
interactive=True,
|
588 |
+
info="This condition is only used when multiple faces are detected on source or specific image.",
|
589 |
+
)
|
590 |
+
detection_size = gr.Number(
|
591 |
+
label="Detection Size", value=DETECT_SIZE, interactive=True
|
592 |
+
)
|
593 |
+
detection_threshold = gr.Number(
|
594 |
+
label="Detection Threshold",
|
595 |
+
value=DETECT_THRESH,
|
596 |
+
interactive=True,
|
597 |
+
)
|
598 |
+
apply_detection_settings = gr.Button("Apply settings")
|
599 |
+
|
600 |
+
with gr.Tab("📤 Output Settings"):
|
601 |
+
output_directory = gr.Text(
|
602 |
+
label="Output Directory",
|
603 |
+
value=DEF_OUTPUT_PATH,
|
604 |
+
interactive=True,
|
605 |
+
)
|
606 |
+
output_name = gr.Text(
|
607 |
+
label="Output Name", value="Result", interactive=True
|
608 |
+
)
|
609 |
+
keep_output_sequence = gr.Checkbox(
|
610 |
+
label="Keep output sequence", value=False, interactive=True
|
611 |
+
)
|
612 |
+
|
613 |
+
with gr.Tab("🪄 Other Settings"):
|
614 |
+
face_scale = gr.Slider(
|
615 |
+
label="Face Scale",
|
616 |
+
minimum=0,
|
617 |
+
maximum=2,
|
618 |
+
value=1,
|
619 |
+
interactive=True,
|
620 |
+
)
|
621 |
+
|
622 |
+
face_enhancer_name = gr.Dropdown(
|
623 |
+
FACE_ENHANCER_LIST, label="Face Enhancer", value="NONE", multiselect=False, interactive=True
|
624 |
+
)
|
625 |
+
|
626 |
+
with gr.Accordion("Advanced Mask", open=False):
|
627 |
+
enable_face_parser_mask = gr.Checkbox(
|
628 |
+
label="Enable Face Parsing",
|
629 |
+
value=False,
|
630 |
+
interactive=True,
|
631 |
+
)
|
632 |
+
|
633 |
+
mask_include = gr.Dropdown(
|
634 |
+
mask_regions.keys(),
|
635 |
+
value=MASK_INCLUDE,
|
636 |
+
multiselect=True,
|
637 |
+
label="Include",
|
638 |
+
interactive=True,
|
639 |
+
)
|
640 |
+
mask_soft_kernel = gr.Number(
|
641 |
+
label="Soft Erode Kernel",
|
642 |
+
value=MASK_SOFT_KERNEL,
|
643 |
+
minimum=3,
|
644 |
+
interactive=True,
|
645 |
+
visible = False
|
646 |
+
)
|
647 |
+
mask_soft_iterations = gr.Number(
|
648 |
+
label="Soft Erode Iterations",
|
649 |
+
value=MASK_SOFT_ITERATIONS,
|
650 |
+
minimum=0,
|
651 |
+
interactive=True,
|
652 |
+
|
653 |
+
)
|
654 |
+
|
655 |
+
|
656 |
+
with gr.Accordion("Crop Mask", open=False):
|
657 |
+
crop_top = gr.Slider(label="Top", minimum=0, maximum=511, value=0, step=1, interactive=True)
|
658 |
+
crop_bott = gr.Slider(label="Bottom", minimum=0, maximum=511, value=511, step=1, interactive=True)
|
659 |
+
crop_left = gr.Slider(label="Left", minimum=0, maximum=511, value=0, step=1, interactive=True)
|
660 |
+
crop_right = gr.Slider(label="Right", minimum=0, maximum=511, value=511, step=1, interactive=True)
|
661 |
+
|
662 |
+
|
663 |
+
erode_amount = gr.Slider(
|
664 |
+
label="Mask Erode",
|
665 |
+
minimum=0,
|
666 |
+
maximum=1,
|
667 |
+
value=MASK_ERODE_AMOUNT,
|
668 |
+
step=0.05,
|
669 |
+
interactive=True,
|
670 |
+
)
|
671 |
+
|
672 |
+
blur_amount = gr.Slider(
|
673 |
+
label="Mask Blur",
|
674 |
+
minimum=0,
|
675 |
+
maximum=1,
|
676 |
+
value=MASK_BLUR_AMOUNT,
|
677 |
+
step=0.05,
|
678 |
+
interactive=True,
|
679 |
+
)
|
680 |
+
|
681 |
+
enable_laplacian_blend = gr.Checkbox(
|
682 |
+
label="Laplacian Blending",
|
683 |
+
value=True,
|
684 |
+
interactive=True,
|
685 |
+
)
|
686 |
+
|
687 |
+
|
688 |
+
source_image_input = gr.Image(
|
689 |
+
label="Source face", type="filepath", interactive=True
|
690 |
+
)
|
691 |
+
|
692 |
+
with gr.Box(visible=False) as specific_face:
|
693 |
+
for i in range(NUM_OF_SRC_SPECIFIC):
|
694 |
+
idx = i + 1
|
695 |
+
code = "\n"
|
696 |
+
code += f"with gr.Tab(label='({idx})'):"
|
697 |
+
code += "\n\twith gr.Row():"
|
698 |
+
code += f"\n\t\tsrc{idx} = gr.Image(interactive=True, type='numpy', label='Source Face {idx}')"
|
699 |
+
code += f"\n\t\ttrg{idx} = gr.Image(interactive=True, type='numpy', label='Specific Face {idx}')"
|
700 |
+
exec(code)
|
701 |
+
|
702 |
+
distance_slider = gr.Slider(
|
703 |
+
minimum=0,
|
704 |
+
maximum=2,
|
705 |
+
value=0.6,
|
706 |
+
interactive=True,
|
707 |
+
label="Distance",
|
708 |
+
info="Lower distance is more similar and higher distance is less similar to the target face.",
|
709 |
+
)
|
710 |
+
|
711 |
+
with gr.Group():
|
712 |
+
input_type = gr.Radio(
|
713 |
+
["Image", "Video"],
|
714 |
+
label="Target Type",
|
715 |
+
value="Image",
|
716 |
+
)
|
717 |
+
|
718 |
+
with gr.Box(visible=True) as input_image_group:
|
719 |
+
image_input = gr.Image(
|
720 |
+
label="Target Image", interactive=True, type="filepath"
|
721 |
+
)
|
722 |
+
|
723 |
+
with gr.Box(visible=False) as input_video_group:
|
724 |
+
vid_widget = gr.Video if USE_COLAB else gr.Text
|
725 |
+
video_input = gr.Video(
|
726 |
+
label="Target Video", interactive=True
|
727 |
+
)
|
728 |
+
with gr.Accordion("✂️ Trim video", open=False):
|
729 |
+
with gr.Column():
|
730 |
+
with gr.Row():
|
731 |
+
set_slider_range_btn = gr.Button(
|
732 |
+
"Set frame range", interactive=True
|
733 |
+
)
|
734 |
+
show_trim_preview_btn = gr.Checkbox(
|
735 |
+
label="Show frame when slider change",
|
736 |
+
value=True,
|
737 |
+
interactive=True,
|
738 |
+
)
|
739 |
+
|
740 |
+
video_fps = gr.Number(
|
741 |
+
value=30,
|
742 |
+
interactive=False,
|
743 |
+
label="Fps",
|
744 |
+
visible=False,
|
745 |
+
)
|
746 |
+
start_frame = gr.Slider(
|
747 |
+
minimum=0,
|
748 |
+
maximum=1,
|
749 |
+
value=0,
|
750 |
+
step=1,
|
751 |
+
interactive=True,
|
752 |
+
label="Start Frame",
|
753 |
+
info="",
|
754 |
+
)
|
755 |
+
end_frame = gr.Slider(
|
756 |
+
minimum=0,
|
757 |
+
maximum=1,
|
758 |
+
value=1,
|
759 |
+
step=1,
|
760 |
+
interactive=True,
|
761 |
+
label="End Frame",
|
762 |
+
info="",
|
763 |
+
)
|
764 |
+
trim_and_reload_btn = gr.Button(
|
765 |
+
"Trim and Reload", interactive=True
|
766 |
+
)
|
767 |
+
|
768 |
+
with gr.Box(visible=False) as input_directory_group:
|
769 |
+
direc_input = gr.Text(label="Path", interactive=True)
|
770 |
+
|
771 |
+
with gr.Column(scale=0.6):
|
772 |
+
info = gr.Markdown(value="...")
|
773 |
+
|
774 |
+
with gr.Row():
|
775 |
+
swap_button = gr.Button("✨ Swap", variant="primary")
|
776 |
+
cancel_button = gr.Button("⛔ Cancel")
|
777 |
+
|
778 |
+
preview_image = gr.Image(label="Output", interactive=False)
|
779 |
+
preview_video = gr.Video(
|
780 |
+
label="Output", interactive=False, visible=False
|
781 |
+
)
|
782 |
+
|
783 |
+
with gr.Row():
|
784 |
+
output_directory_button = gr.Button(
|
785 |
+
"📂", interactive=False, visible=False
|
786 |
+
)
|
787 |
+
output_video_button = gr.Button(
|
788 |
+
"🎬", interactive=False, visible=False
|
789 |
+
)
|
790 |
+
|
791 |
+
with gr.Box():
|
792 |
+
with gr.Row():
|
793 |
+
gr.Markdown(
|
794 |
+
"### [🤝 Sponsor](https://github.com/sponsors/harisreedhar)"
|
795 |
+
)
|
796 |
+
gr.Markdown(
|
797 |
+
"### [👨💻 Source code](https://github.com/harisreedhar/Swap-Mukham)"
|
798 |
+
)
|
799 |
+
gr.Markdown(
|
800 |
+
"### [⚠️ Disclaimer](https://github.com/harisreedhar/Swap-Mukham#disclaimer)"
|
801 |
+
)
|
802 |
+
gr.Markdown(
|
803 |
+
"### [🌐 Run in Colab](https://colab.research.google.com/github/harisreedhar/Swap-Mukham/blob/main/swap_mukham_colab.ipynb)"
|
804 |
+
)
|
805 |
+
gr.Markdown(
|
806 |
+
"### [🤗 Acknowledgements](https://github.com/harisreedhar/Swap-Mukham#acknowledgements)"
|
807 |
+
)
|
808 |
+
|
809 |
+
## ------------------------------ GRADIO EVENTS ------------------------------
|
810 |
+
|
811 |
+
set_slider_range_event = set_slider_range_btn.click(
|
812 |
+
video_changed,
|
813 |
+
inputs=[video_input],
|
814 |
+
outputs=[start_frame, end_frame, video_fps],
|
815 |
+
)
|
816 |
+
|
817 |
+
trim_and_reload_event = trim_and_reload_btn.click(
|
818 |
+
fn=trim_and_reload,
|
819 |
+
inputs=[video_input, output_directory, output_name, start_frame, end_frame],
|
820 |
+
outputs=[video_input, info],
|
821 |
+
)
|
822 |
+
|
823 |
+
start_frame_event = start_frame.release(
|
824 |
+
fn=slider_changed,
|
825 |
+
inputs=[show_trim_preview_btn, video_input, start_frame],
|
826 |
+
outputs=[preview_image, preview_video],
|
827 |
+
show_progress=True,
|
828 |
+
)
|
829 |
+
|
830 |
+
end_frame_event = end_frame.release(
|
831 |
+
fn=slider_changed,
|
832 |
+
inputs=[show_trim_preview_btn, video_input, end_frame],
|
833 |
+
outputs=[preview_image, preview_video],
|
834 |
+
show_progress=True,
|
835 |
+
)
|
836 |
+
|
837 |
+
input_type.change(
|
838 |
+
update_radio,
|
839 |
+
inputs=[input_type],
|
840 |
+
outputs=[input_image_group, input_video_group, input_directory_group],
|
841 |
+
)
|
842 |
+
swap_option.change(
|
843 |
+
swap_option_changed,
|
844 |
+
inputs=[swap_option],
|
845 |
+
outputs=[age, specific_face, source_image_input],
|
846 |
+
)
|
847 |
+
|
848 |
+
apply_detection_settings.click(
|
849 |
+
analyse_settings_changed,
|
850 |
+
inputs=[detect_condition_dropdown, detection_size, detection_threshold],
|
851 |
+
outputs=[info],
|
852 |
+
)
|
853 |
+
|
854 |
+
src_specific_inputs = []
|
855 |
+
gen_variable_txt = ",".join(
|
856 |
+
[f"src{i+1}" for i in range(NUM_OF_SRC_SPECIFIC)]
|
857 |
+
+ [f"trg{i+1}" for i in range(NUM_OF_SRC_SPECIFIC)]
|
858 |
+
)
|
859 |
+
exec(f"src_specific_inputs = ({gen_variable_txt})")
|
860 |
+
swap_inputs = [
|
861 |
+
input_type,
|
862 |
+
image_input,
|
863 |
+
video_input,
|
864 |
+
direc_input,
|
865 |
+
source_image_input,
|
866 |
+
output_directory,
|
867 |
+
output_name,
|
868 |
+
keep_output_sequence,
|
869 |
+
swap_option,
|
870 |
+
age,
|
871 |
+
distance_slider,
|
872 |
+
face_enhancer_name,
|
873 |
+
enable_face_parser_mask,
|
874 |
+
mask_include,
|
875 |
+
mask_soft_kernel,
|
876 |
+
mask_soft_iterations,
|
877 |
+
blur_amount,
|
878 |
+
erode_amount,
|
879 |
+
face_scale,
|
880 |
+
enable_laplacian_blend,
|
881 |
+
crop_top,
|
882 |
+
crop_bott,
|
883 |
+
crop_left,
|
884 |
+
crop_right,
|
885 |
+
*src_specific_inputs,
|
886 |
+
]
|
887 |
+
|
888 |
+
swap_outputs = [
|
889 |
+
info,
|
890 |
+
preview_image,
|
891 |
+
output_directory_button,
|
892 |
+
output_video_button,
|
893 |
+
preview_video,
|
894 |
+
]
|
895 |
+
|
896 |
+
swap_event = swap_button.click(
|
897 |
+
fn=process, inputs=swap_inputs, outputs=swap_outputs, show_progress=True
|
898 |
+
)
|
899 |
+
|
900 |
+
cancel_button.click(
|
901 |
+
fn=stop_running,
|
902 |
+
inputs=None,
|
903 |
+
outputs=[info],
|
904 |
+
cancels=[
|
905 |
+
swap_event,
|
906 |
+
trim_and_reload_event,
|
907 |
+
set_slider_range_event,
|
908 |
+
start_frame_event,
|
909 |
+
end_frame_event,
|
910 |
+
],
|
911 |
+
show_progress=True,
|
912 |
+
)
|
913 |
+
output_directory_button.click(
|
914 |
+
lambda: open_directory(path=WORKSPACE), inputs=None, outputs=None
|
915 |
+
)
|
916 |
+
output_video_button.click(
|
917 |
+
lambda: open_directory(path=OUTPUT_FILE), inputs=None, outputs=None
|
918 |
+
)
|
919 |
+
|
920 |
+
if __name__ == "__main__":
|
921 |
+
if USE_COLAB:
|
922 |
+
print("Running in colab mode")
|
923 |
+
|
924 |
+
interface.queue(concurrency_count=2, max_size=20).launch(share=USE_COLAB)
|
assets/images/logo.png
ADDED
![]() |
assets/pretrained_models/79999_iter.pth
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:468e13ca13a9b43cc0881a9f99083a430e9c0a38abd935431d1c28ee94b26567
|
3 |
+
size 53289463
|
assets/pretrained_models/GFPGANv1.4.pth
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:e2cd4703ab14f4d01fd1383a8a8b266f9a5833dacee8e6a79d3bf21a1b6be5ad
|
3 |
+
size 348632874
|
assets/pretrained_models/RealESRGAN_x2.pth
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:c830d067d54fc767b9543a8432f36d91bc2de313584e8bbfe4ac26a47339e899
|
3 |
+
size 67061725
|
assets/pretrained_models/RealESRGAN_x4.pth
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:aa00f09ad753d88576b21ed977e97d634976377031b178acc3b5b238df463400
|
3 |
+
size 67040989
|
assets/pretrained_models/RealESRGAN_x8.pth
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:8b72fb469d12f05a4770813d2603eb1b550f40df6fb8b37d6c7bc2db3d2bff5e
|
3 |
+
size 67189359
|
assets/pretrained_models/codeformer.onnx
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:91e7e881c5001fea4a535e8f96eaeaa672d30c963a678a3e27f0429a6620f57a
|
3 |
+
size 376821950
|
assets/pretrained_models/inswapper_128.onnx
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:e4a3f08c753cb72d04e10aa0f7dbe3deebbf39567d4ead6dce08e98aa49e16af
|
3 |
+
size 554253681
|
assets/pretrained_models/open-nsfw.onnx
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:864bb37bf8863564b87eb330ab8c785a79a773f4e7c43cb96db52ed8611305fa
|
3 |
+
size 23590724
|
assets/pretrained_models/readme.md
ADDED
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
1 |
+
## Downolad these models here
|
2 |
+
- [inswapper_128.onnx](https://huggingface.co/deepinsight/inswapper/resolve/main/inswapper_128.onnx)
|
3 |
+
- [GFPGANv1.4.pth](https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.4.pth)
|
4 |
+
- [79999_iter.pth](https://drive.google.com/open?id=154JgKpzCPW82qINcVieuPH3fZ2e0P812)
|
face_analyser.py
ADDED
@@ -0,0 +1,194 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import cv2
|
3 |
+
import numpy as np
|
4 |
+
from tqdm import tqdm
|
5 |
+
from utils import scale_bbox_from_center
|
6 |
+
|
7 |
+
detect_conditions = [
|
8 |
+
"best detection",
|
9 |
+
"left most",
|
10 |
+
"right most",
|
11 |
+
"top most",
|
12 |
+
"bottom most",
|
13 |
+
"middle",
|
14 |
+
"biggest",
|
15 |
+
"smallest",
|
16 |
+
]
|
17 |
+
|
18 |
+
swap_options_list = [
|
19 |
+
"All Face",
|
20 |
+
"Specific Face",
|
21 |
+
"Age less than",
|
22 |
+
"Age greater than",
|
23 |
+
"All Male",
|
24 |
+
"All Female",
|
25 |
+
"Left Most",
|
26 |
+
"Right Most",
|
27 |
+
"Top Most",
|
28 |
+
"Bottom Most",
|
29 |
+
"Middle",
|
30 |
+
"Biggest",
|
31 |
+
"Smallest",
|
32 |
+
]
|
33 |
+
|
34 |
+
def get_single_face(faces, method="best detection"):
|
35 |
+
total_faces = len(faces)
|
36 |
+
if total_faces == 1:
|
37 |
+
return faces[0]
|
38 |
+
|
39 |
+
print(f"{total_faces} face detected. Using {method} face.")
|
40 |
+
if method == "best detection":
|
41 |
+
return sorted(faces, key=lambda face: face["det_score"])[-1]
|
42 |
+
elif method == "left most":
|
43 |
+
return sorted(faces, key=lambda face: face["bbox"][0])[0]
|
44 |
+
elif method == "right most":
|
45 |
+
return sorted(faces, key=lambda face: face["bbox"][0])[-1]
|
46 |
+
elif method == "top most":
|
47 |
+
return sorted(faces, key=lambda face: face["bbox"][1])[0]
|
48 |
+
elif method == "bottom most":
|
49 |
+
return sorted(faces, key=lambda face: face["bbox"][1])[-1]
|
50 |
+
elif method == "middle":
|
51 |
+
return sorted(faces, key=lambda face: (
|
52 |
+
(face["bbox"][0] + face["bbox"][2]) / 2 - 0.5) ** 2 +
|
53 |
+
((face["bbox"][1] + face["bbox"][3]) / 2 - 0.5) ** 2)[len(faces) // 2]
|
54 |
+
elif method == "biggest":
|
55 |
+
return sorted(faces, key=lambda face: (face["bbox"][2] - face["bbox"][0]) * (face["bbox"][3] - face["bbox"][1]))[-1]
|
56 |
+
elif method == "smallest":
|
57 |
+
return sorted(faces, key=lambda face: (face["bbox"][2] - face["bbox"][0]) * (face["bbox"][3] - face["bbox"][1]))[0]
|
58 |
+
|
59 |
+
|
60 |
+
def analyse_face(image, model, return_single_face=True, detect_condition="best detection", scale=1.0):
|
61 |
+
faces = model.get(image)
|
62 |
+
if scale != 1: # landmark-scale
|
63 |
+
for i, face in enumerate(faces):
|
64 |
+
landmark = face['kps']
|
65 |
+
center = np.mean(landmark, axis=0)
|
66 |
+
landmark = center + (landmark - center) * scale
|
67 |
+
faces[i]['kps'] = landmark
|
68 |
+
|
69 |
+
if not return_single_face:
|
70 |
+
return faces
|
71 |
+
|
72 |
+
return get_single_face(faces, method=detect_condition)
|
73 |
+
|
74 |
+
|
75 |
+
def cosine_distance(a, b):
|
76 |
+
a /= np.linalg.norm(a)
|
77 |
+
b /= np.linalg.norm(b)
|
78 |
+
return 1 - np.dot(a, b)
|
79 |
+
|
80 |
+
|
81 |
+
def get_analysed_data(face_analyser, image_sequence, source_data, swap_condition="All face", detect_condition="left most", scale=1.0):
|
82 |
+
if swap_condition != "Specific Face":
|
83 |
+
source_path, age = source_data
|
84 |
+
source_image = cv2.imread(source_path)
|
85 |
+
analysed_source = analyse_face(source_image, face_analyser, return_single_face=True, detect_condition=detect_condition, scale=scale)
|
86 |
+
else:
|
87 |
+
analysed_source_specifics = []
|
88 |
+
source_specifics, threshold = source_data
|
89 |
+
for source, specific in zip(*source_specifics):
|
90 |
+
if source is None or specific is None:
|
91 |
+
continue
|
92 |
+
analysed_source = analyse_face(source, face_analyser, return_single_face=True, detect_condition=detect_condition, scale=scale)
|
93 |
+
analysed_specific = analyse_face(specific, face_analyser, return_single_face=True, detect_condition=detect_condition, scale=scale)
|
94 |
+
analysed_source_specifics.append([analysed_source, analysed_specific])
|
95 |
+
|
96 |
+
analysed_target_list = []
|
97 |
+
analysed_source_list = []
|
98 |
+
whole_frame_eql_list = []
|
99 |
+
num_faces_per_frame = []
|
100 |
+
|
101 |
+
total_frames = len(image_sequence)
|
102 |
+
curr_idx = 0
|
103 |
+
for curr_idx, frame_path in tqdm(enumerate(image_sequence), total=total_frames, desc="Analysing face data"):
|
104 |
+
frame = cv2.imread(frame_path)
|
105 |
+
analysed_faces = analyse_face(frame, face_analyser, return_single_face=False, detect_condition=detect_condition, scale=scale)
|
106 |
+
|
107 |
+
n_faces = 0
|
108 |
+
for analysed_face in analysed_faces:
|
109 |
+
if swap_condition == "All Face":
|
110 |
+
analysed_target_list.append(analysed_face)
|
111 |
+
analysed_source_list.append(analysed_source)
|
112 |
+
whole_frame_eql_list.append(frame_path)
|
113 |
+
n_faces += 1
|
114 |
+
elif swap_condition == "Age less than" and analysed_face["age"] < age:
|
115 |
+
analysed_target_list.append(analysed_face)
|
116 |
+
analysed_source_list.append(analysed_source)
|
117 |
+
whole_frame_eql_list.append(frame_path)
|
118 |
+
n_faces += 1
|
119 |
+
elif swap_condition == "Age greater than" and analysed_face["age"] > age:
|
120 |
+
analysed_target_list.append(analysed_face)
|
121 |
+
analysed_source_list.append(analysed_source)
|
122 |
+
whole_frame_eql_list.append(frame_path)
|
123 |
+
n_faces += 1
|
124 |
+
elif swap_condition == "All Male" and analysed_face["gender"] == 1:
|
125 |
+
analysed_target_list.append(analysed_face)
|
126 |
+
analysed_source_list.append(analysed_source)
|
127 |
+
whole_frame_eql_list.append(frame_path)
|
128 |
+
n_faces += 1
|
129 |
+
elif swap_condition == "All Female" and analysed_face["gender"] == 0:
|
130 |
+
analysed_target_list.append(analysed_face)
|
131 |
+
analysed_source_list.append(analysed_source)
|
132 |
+
whole_frame_eql_list.append(frame_path)
|
133 |
+
n_faces += 1
|
134 |
+
elif swap_condition == "Specific Face":
|
135 |
+
for analysed_source, analysed_specific in analysed_source_specifics:
|
136 |
+
distance = cosine_distance(analysed_specific["embedding"], analysed_face["embedding"])
|
137 |
+
if distance < threshold:
|
138 |
+
analysed_target_list.append(analysed_face)
|
139 |
+
analysed_source_list.append(analysed_source)
|
140 |
+
whole_frame_eql_list.append(frame_path)
|
141 |
+
n_faces += 1
|
142 |
+
|
143 |
+
if swap_condition == "Left Most":
|
144 |
+
analysed_face = get_single_face(analysed_faces, method="left most")
|
145 |
+
analysed_target_list.append(analysed_face)
|
146 |
+
analysed_source_list.append(analysed_source)
|
147 |
+
whole_frame_eql_list.append(frame_path)
|
148 |
+
n_faces += 1
|
149 |
+
|
150 |
+
elif swap_condition == "Right Most":
|
151 |
+
analysed_face = get_single_face(analysed_faces, method="right most")
|
152 |
+
analysed_target_list.append(analysed_face)
|
153 |
+
analysed_source_list.append(analysed_source)
|
154 |
+
whole_frame_eql_list.append(frame_path)
|
155 |
+
n_faces += 1
|
156 |
+
|
157 |
+
elif swap_condition == "Top Most":
|
158 |
+
analysed_face = get_single_face(analysed_faces, method="top most")
|
159 |
+
analysed_target_list.append(analysed_face)
|
160 |
+
analysed_source_list.append(analysed_source)
|
161 |
+
whole_frame_eql_list.append(frame_path)
|
162 |
+
n_faces += 1
|
163 |
+
|
164 |
+
elif swap_condition == "Bottom Most":
|
165 |
+
analysed_face = get_single_face(analysed_faces, method="bottom most")
|
166 |
+
analysed_target_list.append(analysed_face)
|
167 |
+
analysed_source_list.append(analysed_source)
|
168 |
+
whole_frame_eql_list.append(frame_path)
|
169 |
+
n_faces += 1
|
170 |
+
|
171 |
+
elif swap_condition == "Middle":
|
172 |
+
analysed_face = get_single_face(analysed_faces, method="middle")
|
173 |
+
analysed_target_list.append(analysed_face)
|
174 |
+
analysed_source_list.append(analysed_source)
|
175 |
+
whole_frame_eql_list.append(frame_path)
|
176 |
+
n_faces += 1
|
177 |
+
|
178 |
+
elif swap_condition == "Biggest":
|
179 |
+
analysed_face = get_single_face(analysed_faces, method="biggest")
|
180 |
+
analysed_target_list.append(analysed_face)
|
181 |
+
analysed_source_list.append(analysed_source)
|
182 |
+
whole_frame_eql_list.append(frame_path)
|
183 |
+
n_faces += 1
|
184 |
+
|
185 |
+
elif swap_condition == "Smallest":
|
186 |
+
analysed_face = get_single_face(analysed_faces, method="smallest")
|
187 |
+
analysed_target_list.append(analysed_face)
|
188 |
+
analysed_source_list.append(analysed_source)
|
189 |
+
whole_frame_eql_list.append(frame_path)
|
190 |
+
n_faces += 1
|
191 |
+
|
192 |
+
num_faces_per_frame.append(n_faces)
|
193 |
+
|
194 |
+
return analysed_target_list, analysed_source_list, whole_frame_eql_list, num_faces_per_frame
|
face_enhancer.py
ADDED
@@ -0,0 +1,72 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import cv2
|
3 |
+
import torch
|
4 |
+
import gfpgan
|
5 |
+
from PIL import Image
|
6 |
+
from upscaler.RealESRGAN import RealESRGAN
|
7 |
+
from upscaler.codeformer import CodeFormerEnhancer
|
8 |
+
|
9 |
+
def gfpgan_runner(img, model):
|
10 |
+
_, imgs, _ = model.enhance(img, paste_back=True, has_aligned=True)
|
11 |
+
return imgs[0]
|
12 |
+
|
13 |
+
|
14 |
+
def realesrgan_runner(img, model):
|
15 |
+
img = model.predict(img)
|
16 |
+
return img
|
17 |
+
|
18 |
+
|
19 |
+
def codeformer_runner(img, model):
|
20 |
+
img = model.enhance(img)
|
21 |
+
return img
|
22 |
+
|
23 |
+
|
24 |
+
supported_enhancers = {
|
25 |
+
"CodeFormer": ("./assets/pretrained_models/codeformer.onnx", codeformer_runner),
|
26 |
+
"GFPGAN": ("./assets/pretrained_models/GFPGANv1.4.pth", gfpgan_runner),
|
27 |
+
"REAL-ESRGAN 2x": ("./assets/pretrained_models/RealESRGAN_x2.pth", realesrgan_runner),
|
28 |
+
"REAL-ESRGAN 4x": ("./assets/pretrained_models/RealESRGAN_x4.pth", realesrgan_runner),
|
29 |
+
"REAL-ESRGAN 8x": ("./assets/pretrained_models/RealESRGAN_x8.pth", realesrgan_runner)
|
30 |
+
}
|
31 |
+
|
32 |
+
cv2_interpolations = ["LANCZOS4", "CUBIC", "NEAREST"]
|
33 |
+
|
34 |
+
def get_available_enhancer_names():
|
35 |
+
available = []
|
36 |
+
for name, data in supported_enhancers.items():
|
37 |
+
path = os.path.join(os.path.abspath(os.path.dirname(__file__)), data[0])
|
38 |
+
if os.path.exists(path):
|
39 |
+
available.append(name)
|
40 |
+
return available
|
41 |
+
|
42 |
+
|
43 |
+
def load_face_enhancer_model(name='GFPGAN', device="cpu"):
|
44 |
+
assert name in get_available_enhancer_names() + cv2_interpolations, f"Face enhancer {name} unavailable."
|
45 |
+
if name in supported_enhancers.keys():
|
46 |
+
model_path, model_runner = supported_enhancers.get(name)
|
47 |
+
model_path = os.path.join(os.path.abspath(os.path.dirname(__file__)), model_path)
|
48 |
+
if name == 'CodeFormer':
|
49 |
+
model = CodeFormerEnhancer(model_path=model_path, device=device)
|
50 |
+
elif name == 'GFPGAN':
|
51 |
+
model = gfpgan.GFPGANer(model_path=model_path, upscale=1, device=device)
|
52 |
+
elif name == 'REAL-ESRGAN 2x':
|
53 |
+
model = RealESRGAN(device, scale=2)
|
54 |
+
model.load_weights(model_path, download=False)
|
55 |
+
elif name == 'REAL-ESRGAN 4x':
|
56 |
+
model = RealESRGAN(device, scale=4)
|
57 |
+
model.load_weights(model_path, download=False)
|
58 |
+
elif name == 'REAL-ESRGAN 8x':
|
59 |
+
model = RealESRGAN(device, scale=8)
|
60 |
+
model.load_weights(model_path, download=False)
|
61 |
+
elif name == 'LANCZOS4':
|
62 |
+
model = None
|
63 |
+
model_runner = lambda img, _: cv2.resize(img, (512,512), interpolation=cv2.INTER_LANCZOS4)
|
64 |
+
elif name == 'CUBIC':
|
65 |
+
model = None
|
66 |
+
model_runner = lambda img, _: cv2.resize(img, (512,512), interpolation=cv2.INTER_CUBIC)
|
67 |
+
elif name == 'NEAREST':
|
68 |
+
model = None
|
69 |
+
model_runner = lambda img, _: cv2.resize(img, (512,512), interpolation=cv2.INTER_NEAREST)
|
70 |
+
else:
|
71 |
+
model = None
|
72 |
+
return (model, model_runner)
|
face_parsing/__init__.py
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
from .swap import init_parser, swap_regions, mask_regions, mask_regions_to_list
|
2 |
+
from .model import BiSeNet
|
3 |
+
from .parse_mask import init_parsing_model, get_parsed_mask, SoftErosion
|
face_parsing/model.py
ADDED
@@ -0,0 +1,283 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/python
|
2 |
+
# -*- encoding: utf-8 -*-
|
3 |
+
|
4 |
+
|
5 |
+
import torch
|
6 |
+
import torch.nn as nn
|
7 |
+
import torch.nn.functional as F
|
8 |
+
import torchvision
|
9 |
+
|
10 |
+
from .resnet import Resnet18
|
11 |
+
# from modules.bn import InPlaceABNSync as BatchNorm2d
|
12 |
+
|
13 |
+
|
14 |
+
class ConvBNReLU(nn.Module):
|
15 |
+
def __init__(self, in_chan, out_chan, ks=3, stride=1, padding=1, *args, **kwargs):
|
16 |
+
super(ConvBNReLU, self).__init__()
|
17 |
+
self.conv = nn.Conv2d(in_chan,
|
18 |
+
out_chan,
|
19 |
+
kernel_size = ks,
|
20 |
+
stride = stride,
|
21 |
+
padding = padding,
|
22 |
+
bias = False)
|
23 |
+
self.bn = nn.BatchNorm2d(out_chan)
|
24 |
+
self.init_weight()
|
25 |
+
|
26 |
+
def forward(self, x):
|
27 |
+
x = self.conv(x)
|
28 |
+
x = F.relu(self.bn(x))
|
29 |
+
return x
|
30 |
+
|
31 |
+
def init_weight(self):
|
32 |
+
for ly in self.children():
|
33 |
+
if isinstance(ly, nn.Conv2d):
|
34 |
+
nn.init.kaiming_normal_(ly.weight, a=1)
|
35 |
+
if not ly.bias is None: nn.init.constant_(ly.bias, 0)
|
36 |
+
|
37 |
+
class BiSeNetOutput(nn.Module):
|
38 |
+
def __init__(self, in_chan, mid_chan, n_classes, *args, **kwargs):
|
39 |
+
super(BiSeNetOutput, self).__init__()
|
40 |
+
self.conv = ConvBNReLU(in_chan, mid_chan, ks=3, stride=1, padding=1)
|
41 |
+
self.conv_out = nn.Conv2d(mid_chan, n_classes, kernel_size=1, bias=False)
|
42 |
+
self.init_weight()
|
43 |
+
|
44 |
+
def forward(self, x):
|
45 |
+
x = self.conv(x)
|
46 |
+
x = self.conv_out(x)
|
47 |
+
return x
|
48 |
+
|
49 |
+
def init_weight(self):
|
50 |
+
for ly in self.children():
|
51 |
+
if isinstance(ly, nn.Conv2d):
|
52 |
+
nn.init.kaiming_normal_(ly.weight, a=1)
|
53 |
+
if not ly.bias is None: nn.init.constant_(ly.bias, 0)
|
54 |
+
|
55 |
+
def get_params(self):
|
56 |
+
wd_params, nowd_params = [], []
|
57 |
+
for name, module in self.named_modules():
|
58 |
+
if isinstance(module, nn.Linear) or isinstance(module, nn.Conv2d):
|
59 |
+
wd_params.append(module.weight)
|
60 |
+
if not module.bias is None:
|
61 |
+
nowd_params.append(module.bias)
|
62 |
+
elif isinstance(module, nn.BatchNorm2d):
|
63 |
+
nowd_params += list(module.parameters())
|
64 |
+
return wd_params, nowd_params
|
65 |
+
|
66 |
+
|
67 |
+
class AttentionRefinementModule(nn.Module):
|
68 |
+
def __init__(self, in_chan, out_chan, *args, **kwargs):
|
69 |
+
super(AttentionRefinementModule, self).__init__()
|
70 |
+
self.conv = ConvBNReLU(in_chan, out_chan, ks=3, stride=1, padding=1)
|
71 |
+
self.conv_atten = nn.Conv2d(out_chan, out_chan, kernel_size= 1, bias=False)
|
72 |
+
self.bn_atten = nn.BatchNorm2d(out_chan)
|
73 |
+
self.sigmoid_atten = nn.Sigmoid()
|
74 |
+
self.init_weight()
|
75 |
+
|
76 |
+
def forward(self, x):
|
77 |
+
feat = self.conv(x)
|
78 |
+
atten = F.avg_pool2d(feat, feat.size()[2:])
|
79 |
+
atten = self.conv_atten(atten)
|
80 |
+
atten = self.bn_atten(atten)
|
81 |
+
atten = self.sigmoid_atten(atten)
|
82 |
+
out = torch.mul(feat, atten)
|
83 |
+
return out
|
84 |
+
|
85 |
+
def init_weight(self):
|
86 |
+
for ly in self.children():
|
87 |
+
if isinstance(ly, nn.Conv2d):
|
88 |
+
nn.init.kaiming_normal_(ly.weight, a=1)
|
89 |
+
if not ly.bias is None: nn.init.constant_(ly.bias, 0)
|
90 |
+
|
91 |
+
|
92 |
+
class ContextPath(nn.Module):
|
93 |
+
def __init__(self, *args, **kwargs):
|
94 |
+
super(ContextPath, self).__init__()
|
95 |
+
self.resnet = Resnet18()
|
96 |
+
self.arm16 = AttentionRefinementModule(256, 128)
|
97 |
+
self.arm32 = AttentionRefinementModule(512, 128)
|
98 |
+
self.conv_head32 = ConvBNReLU(128, 128, ks=3, stride=1, padding=1)
|
99 |
+
self.conv_head16 = ConvBNReLU(128, 128, ks=3, stride=1, padding=1)
|
100 |
+
self.conv_avg = ConvBNReLU(512, 128, ks=1, stride=1, padding=0)
|
101 |
+
|
102 |
+
self.init_weight()
|
103 |
+
|
104 |
+
def forward(self, x):
|
105 |
+
H0, W0 = x.size()[2:]
|
106 |
+
feat8, feat16, feat32 = self.resnet(x)
|
107 |
+
H8, W8 = feat8.size()[2:]
|
108 |
+
H16, W16 = feat16.size()[2:]
|
109 |
+
H32, W32 = feat32.size()[2:]
|
110 |
+
|
111 |
+
avg = F.avg_pool2d(feat32, feat32.size()[2:])
|
112 |
+
avg = self.conv_avg(avg)
|
113 |
+
avg_up = F.interpolate(avg, (H32, W32), mode='nearest')
|
114 |
+
|
115 |
+
feat32_arm = self.arm32(feat32)
|
116 |
+
feat32_sum = feat32_arm + avg_up
|
117 |
+
feat32_up = F.interpolate(feat32_sum, (H16, W16), mode='nearest')
|
118 |
+
feat32_up = self.conv_head32(feat32_up)
|
119 |
+
|
120 |
+
feat16_arm = self.arm16(feat16)
|
121 |
+
feat16_sum = feat16_arm + feat32_up
|
122 |
+
feat16_up = F.interpolate(feat16_sum, (H8, W8), mode='nearest')
|
123 |
+
feat16_up = self.conv_head16(feat16_up)
|
124 |
+
|
125 |
+
return feat8, feat16_up, feat32_up # x8, x8, x16
|
126 |
+
|
127 |
+
def init_weight(self):
|
128 |
+
for ly in self.children():
|
129 |
+
if isinstance(ly, nn.Conv2d):
|
130 |
+
nn.init.kaiming_normal_(ly.weight, a=1)
|
131 |
+
if not ly.bias is None: nn.init.constant_(ly.bias, 0)
|
132 |
+
|
133 |
+
def get_params(self):
|
134 |
+
wd_params, nowd_params = [], []
|
135 |
+
for name, module in self.named_modules():
|
136 |
+
if isinstance(module, (nn.Linear, nn.Conv2d)):
|
137 |
+
wd_params.append(module.weight)
|
138 |
+
if not module.bias is None:
|
139 |
+
nowd_params.append(module.bias)
|
140 |
+
elif isinstance(module, nn.BatchNorm2d):
|
141 |
+
nowd_params += list(module.parameters())
|
142 |
+
return wd_params, nowd_params
|
143 |
+
|
144 |
+
|
145 |
+
### This is not used, since I replace this with the resnet feature with the same size
|
146 |
+
class SpatialPath(nn.Module):
|
147 |
+
def __init__(self, *args, **kwargs):
|
148 |
+
super(SpatialPath, self).__init__()
|
149 |
+
self.conv1 = ConvBNReLU(3, 64, ks=7, stride=2, padding=3)
|
150 |
+
self.conv2 = ConvBNReLU(64, 64, ks=3, stride=2, padding=1)
|
151 |
+
self.conv3 = ConvBNReLU(64, 64, ks=3, stride=2, padding=1)
|
152 |
+
self.conv_out = ConvBNReLU(64, 128, ks=1, stride=1, padding=0)
|
153 |
+
self.init_weight()
|
154 |
+
|
155 |
+
def forward(self, x):
|
156 |
+
feat = self.conv1(x)
|
157 |
+
feat = self.conv2(feat)
|
158 |
+
feat = self.conv3(feat)
|
159 |
+
feat = self.conv_out(feat)
|
160 |
+
return feat
|
161 |
+
|
162 |
+
def init_weight(self):
|
163 |
+
for ly in self.children():
|
164 |
+
if isinstance(ly, nn.Conv2d):
|
165 |
+
nn.init.kaiming_normal_(ly.weight, a=1)
|
166 |
+
if not ly.bias is None: nn.init.constant_(ly.bias, 0)
|
167 |
+
|
168 |
+
def get_params(self):
|
169 |
+
wd_params, nowd_params = [], []
|
170 |
+
for name, module in self.named_modules():
|
171 |
+
if isinstance(module, nn.Linear) or isinstance(module, nn.Conv2d):
|
172 |
+
wd_params.append(module.weight)
|
173 |
+
if not module.bias is None:
|
174 |
+
nowd_params.append(module.bias)
|
175 |
+
elif isinstance(module, nn.BatchNorm2d):
|
176 |
+
nowd_params += list(module.parameters())
|
177 |
+
return wd_params, nowd_params
|
178 |
+
|
179 |
+
|
180 |
+
class FeatureFusionModule(nn.Module):
|
181 |
+
def __init__(self, in_chan, out_chan, *args, **kwargs):
|
182 |
+
super(FeatureFusionModule, self).__init__()
|
183 |
+
self.convblk = ConvBNReLU(in_chan, out_chan, ks=1, stride=1, padding=0)
|
184 |
+
self.conv1 = nn.Conv2d(out_chan,
|
185 |
+
out_chan//4,
|
186 |
+
kernel_size = 1,
|
187 |
+
stride = 1,
|
188 |
+
padding = 0,
|
189 |
+
bias = False)
|
190 |
+
self.conv2 = nn.Conv2d(out_chan//4,
|
191 |
+
out_chan,
|
192 |
+
kernel_size = 1,
|
193 |
+
stride = 1,
|
194 |
+
padding = 0,
|
195 |
+
bias = False)
|
196 |
+
self.relu = nn.ReLU(inplace=True)
|
197 |
+
self.sigmoid = nn.Sigmoid()
|
198 |
+
self.init_weight()
|
199 |
+
|
200 |
+
def forward(self, fsp, fcp):
|
201 |
+
fcat = torch.cat([fsp, fcp], dim=1)
|
202 |
+
feat = self.convblk(fcat)
|
203 |
+
atten = F.avg_pool2d(feat, feat.size()[2:])
|
204 |
+
atten = self.conv1(atten)
|
205 |
+
atten = self.relu(atten)
|
206 |
+
atten = self.conv2(atten)
|
207 |
+
atten = self.sigmoid(atten)
|
208 |
+
feat_atten = torch.mul(feat, atten)
|
209 |
+
feat_out = feat_atten + feat
|
210 |
+
return feat_out
|
211 |
+
|
212 |
+
def init_weight(self):
|
213 |
+
for ly in self.children():
|
214 |
+
if isinstance(ly, nn.Conv2d):
|
215 |
+
nn.init.kaiming_normal_(ly.weight, a=1)
|
216 |
+
if not ly.bias is None: nn.init.constant_(ly.bias, 0)
|
217 |
+
|
218 |
+
def get_params(self):
|
219 |
+
wd_params, nowd_params = [], []
|
220 |
+
for name, module in self.named_modules():
|
221 |
+
if isinstance(module, nn.Linear) or isinstance(module, nn.Conv2d):
|
222 |
+
wd_params.append(module.weight)
|
223 |
+
if not module.bias is None:
|
224 |
+
nowd_params.append(module.bias)
|
225 |
+
elif isinstance(module, nn.BatchNorm2d):
|
226 |
+
nowd_params += list(module.parameters())
|
227 |
+
return wd_params, nowd_params
|
228 |
+
|
229 |
+
|
230 |
+
class BiSeNet(nn.Module):
|
231 |
+
def __init__(self, n_classes, *args, **kwargs):
|
232 |
+
super(BiSeNet, self).__init__()
|
233 |
+
self.cp = ContextPath()
|
234 |
+
## here self.sp is deleted
|
235 |
+
self.ffm = FeatureFusionModule(256, 256)
|
236 |
+
self.conv_out = BiSeNetOutput(256, 256, n_classes)
|
237 |
+
self.conv_out16 = BiSeNetOutput(128, 64, n_classes)
|
238 |
+
self.conv_out32 = BiSeNetOutput(128, 64, n_classes)
|
239 |
+
self.init_weight()
|
240 |
+
|
241 |
+
def forward(self, x):
|
242 |
+
H, W = x.size()[2:]
|
243 |
+
feat_res8, feat_cp8, feat_cp16 = self.cp(x) # here return res3b1 feature
|
244 |
+
feat_sp = feat_res8 # use res3b1 feature to replace spatial path feature
|
245 |
+
feat_fuse = self.ffm(feat_sp, feat_cp8)
|
246 |
+
|
247 |
+
feat_out = self.conv_out(feat_fuse)
|
248 |
+
feat_out16 = self.conv_out16(feat_cp8)
|
249 |
+
feat_out32 = self.conv_out32(feat_cp16)
|
250 |
+
|
251 |
+
feat_out = F.interpolate(feat_out, (H, W), mode='bilinear', align_corners=True)
|
252 |
+
feat_out16 = F.interpolate(feat_out16, (H, W), mode='bilinear', align_corners=True)
|
253 |
+
feat_out32 = F.interpolate(feat_out32, (H, W), mode='bilinear', align_corners=True)
|
254 |
+
return feat_out, feat_out16, feat_out32
|
255 |
+
|
256 |
+
def init_weight(self):
|
257 |
+
for ly in self.children():
|
258 |
+
if isinstance(ly, nn.Conv2d):
|
259 |
+
nn.init.kaiming_normal_(ly.weight, a=1)
|
260 |
+
if not ly.bias is None: nn.init.constant_(ly.bias, 0)
|
261 |
+
|
262 |
+
def get_params(self):
|
263 |
+
wd_params, nowd_params, lr_mul_wd_params, lr_mul_nowd_params = [], [], [], []
|
264 |
+
for name, child in self.named_children():
|
265 |
+
child_wd_params, child_nowd_params = child.get_params()
|
266 |
+
if isinstance(child, FeatureFusionModule) or isinstance(child, BiSeNetOutput):
|
267 |
+
lr_mul_wd_params += child_wd_params
|
268 |
+
lr_mul_nowd_params += child_nowd_params
|
269 |
+
else:
|
270 |
+
wd_params += child_wd_params
|
271 |
+
nowd_params += child_nowd_params
|
272 |
+
return wd_params, nowd_params, lr_mul_wd_params, lr_mul_nowd_params
|
273 |
+
|
274 |
+
|
275 |
+
if __name__ == "__main__":
|
276 |
+
net = BiSeNet(19)
|
277 |
+
net.cuda()
|
278 |
+
net.eval()
|
279 |
+
in_ten = torch.randn(16, 3, 640, 480).cuda()
|
280 |
+
out, out16, out32 = net(in_ten)
|
281 |
+
print(out.shape)
|
282 |
+
|
283 |
+
net.get_params()
|
face_parsing/parse_mask.py
ADDED
@@ -0,0 +1,107 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import cv2
|
2 |
+
import torch
|
3 |
+
import torchvision
|
4 |
+
import numpy as np
|
5 |
+
import torch.nn as nn
|
6 |
+
from PIL import Image
|
7 |
+
from tqdm import tqdm
|
8 |
+
import torch.nn.functional as F
|
9 |
+
import torchvision.transforms as transforms
|
10 |
+
|
11 |
+
from . model import BiSeNet
|
12 |
+
|
13 |
+
class SoftErosion(nn.Module):
|
14 |
+
def __init__(self, kernel_size=15, threshold=0.6, iterations=1):
|
15 |
+
super(SoftErosion, self).__init__()
|
16 |
+
r = kernel_size // 2
|
17 |
+
self.padding = r
|
18 |
+
self.iterations = iterations
|
19 |
+
self.threshold = threshold
|
20 |
+
|
21 |
+
# Create kernel
|
22 |
+
y_indices, x_indices = torch.meshgrid(torch.arange(0., kernel_size), torch.arange(0., kernel_size))
|
23 |
+
dist = torch.sqrt((x_indices - r) ** 2 + (y_indices - r) ** 2)
|
24 |
+
kernel = dist.max() - dist
|
25 |
+
kernel /= kernel.sum()
|
26 |
+
kernel = kernel.view(1, 1, *kernel.shape)
|
27 |
+
self.register_buffer('weight', kernel)
|
28 |
+
|
29 |
+
def forward(self, x):
|
30 |
+
batch_size = x.size(0) # Get the batch size
|
31 |
+
output = []
|
32 |
+
|
33 |
+
for i in tqdm(range(batch_size), desc="Soft-Erosion", leave=False):
|
34 |
+
input_tensor = x[i:i+1] # Take one input tensor from the batch
|
35 |
+
input_tensor = input_tensor.float() # Convert input to float tensor
|
36 |
+
input_tensor = input_tensor.unsqueeze(1) # Add a channel dimension
|
37 |
+
|
38 |
+
for _ in range(self.iterations - 1):
|
39 |
+
input_tensor = torch.min(input_tensor, F.conv2d(input_tensor, weight=self.weight,
|
40 |
+
groups=input_tensor.shape[1],
|
41 |
+
padding=self.padding))
|
42 |
+
input_tensor = F.conv2d(input_tensor, weight=self.weight, groups=input_tensor.shape[1],
|
43 |
+
padding=self.padding)
|
44 |
+
|
45 |
+
mask = input_tensor >= self.threshold
|
46 |
+
input_tensor[mask] = 1.0
|
47 |
+
input_tensor[~mask] /= input_tensor[~mask].max()
|
48 |
+
|
49 |
+
input_tensor = input_tensor.squeeze(1) # Remove the extra channel dimension
|
50 |
+
output.append(input_tensor.detach().cpu().numpy())
|
51 |
+
|
52 |
+
return np.array(output)
|
53 |
+
|
54 |
+
transform = transforms.Compose([
|
55 |
+
transforms.Resize((512, 512)),
|
56 |
+
transforms.ToTensor(),
|
57 |
+
transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
|
58 |
+
])
|
59 |
+
|
60 |
+
|
61 |
+
|
62 |
+
def init_parsing_model(model_path, device="cpu"):
|
63 |
+
net = BiSeNet(19)
|
64 |
+
net.to(device)
|
65 |
+
net.load_state_dict(torch.load(model_path))
|
66 |
+
net.eval()
|
67 |
+
return net
|
68 |
+
|
69 |
+
def transform_images(imgs):
|
70 |
+
tensor_images = torch.stack([transform(Image.fromarray(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))) for img in imgs], dim=0)
|
71 |
+
return tensor_images
|
72 |
+
|
73 |
+
def get_parsed_mask(net, imgs, classes=[1, 2, 3, 4, 5, 10, 11, 12, 13], device="cpu", batch_size=8, softness=20):
|
74 |
+
if softness > 0:
|
75 |
+
smooth_mask = SoftErosion(kernel_size=17, threshold=0.9, iterations=softness).to(device)
|
76 |
+
|
77 |
+
masks = []
|
78 |
+
for i in tqdm(range(0, len(imgs), batch_size), total=len(imgs) // batch_size, desc="Face-parsing"):
|
79 |
+
batch_imgs = imgs[i:i + batch_size]
|
80 |
+
|
81 |
+
tensor_images = transform_images(batch_imgs).to(device)
|
82 |
+
with torch.no_grad():
|
83 |
+
out = net(tensor_images)[0]
|
84 |
+
# parsing = out.argmax(dim=1)
|
85 |
+
# arget_classes = torch.tensor(classes).to(device)
|
86 |
+
# batch_masks = torch.isin(parsing, target_classes).to(device)
|
87 |
+
## torch.isin was slightly slower in my test, so using np.isin
|
88 |
+
parsing = out.argmax(dim=1).detach().cpu().numpy()
|
89 |
+
batch_masks = np.isin(parsing, classes).astype('float32')
|
90 |
+
|
91 |
+
if softness > 0:
|
92 |
+
# batch_masks = smooth_mask(batch_masks).transpose(1,0,2,3)[0]
|
93 |
+
mask_tensor = torch.from_numpy(batch_masks.copy()).float().to(device)
|
94 |
+
batch_masks = smooth_mask(mask_tensor).transpose(1,0,2,3)[0]
|
95 |
+
|
96 |
+
yield batch_masks
|
97 |
+
|
98 |
+
#masks.append(batch_masks)
|
99 |
+
|
100 |
+
#if len(masks) >= 1:
|
101 |
+
# masks = np.concatenate(masks, axis=0)
|
102 |
+
# masks = np.repeat(np.expand_dims(masks, axis=1), 3, axis=1)
|
103 |
+
|
104 |
+
# for i, mask in enumerate(masks):
|
105 |
+
# cv2.imwrite(f"mask/{i}.jpg", (mask * 255).astype("uint8"))
|
106 |
+
|
107 |
+
#return masks
|
face_parsing/resnet.py
ADDED
@@ -0,0 +1,109 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/python
|
2 |
+
# -*- encoding: utf-8 -*-
|
3 |
+
|
4 |
+
import torch
|
5 |
+
import torch.nn as nn
|
6 |
+
import torch.nn.functional as F
|
7 |
+
import torch.utils.model_zoo as modelzoo
|
8 |
+
|
9 |
+
# from modules.bn import InPlaceABNSync as BatchNorm2d
|
10 |
+
|
11 |
+
resnet18_url = 'https://download.pytorch.org/models/resnet18-5c106cde.pth'
|
12 |
+
|
13 |
+
|
14 |
+
def conv3x3(in_planes, out_planes, stride=1):
|
15 |
+
"""3x3 convolution with padding"""
|
16 |
+
return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
|
17 |
+
padding=1, bias=False)
|
18 |
+
|
19 |
+
|
20 |
+
class BasicBlock(nn.Module):
|
21 |
+
def __init__(self, in_chan, out_chan, stride=1):
|
22 |
+
super(BasicBlock, self).__init__()
|
23 |
+
self.conv1 = conv3x3(in_chan, out_chan, stride)
|
24 |
+
self.bn1 = nn.BatchNorm2d(out_chan)
|
25 |
+
self.conv2 = conv3x3(out_chan, out_chan)
|
26 |
+
self.bn2 = nn.BatchNorm2d(out_chan)
|
27 |
+
self.relu = nn.ReLU(inplace=True)
|
28 |
+
self.downsample = None
|
29 |
+
if in_chan != out_chan or stride != 1:
|
30 |
+
self.downsample = nn.Sequential(
|
31 |
+
nn.Conv2d(in_chan, out_chan,
|
32 |
+
kernel_size=1, stride=stride, bias=False),
|
33 |
+
nn.BatchNorm2d(out_chan),
|
34 |
+
)
|
35 |
+
|
36 |
+
def forward(self, x):
|
37 |
+
residual = self.conv1(x)
|
38 |
+
residual = F.relu(self.bn1(residual))
|
39 |
+
residual = self.conv2(residual)
|
40 |
+
residual = self.bn2(residual)
|
41 |
+
|
42 |
+
shortcut = x
|
43 |
+
if self.downsample is not None:
|
44 |
+
shortcut = self.downsample(x)
|
45 |
+
|
46 |
+
out = shortcut + residual
|
47 |
+
out = self.relu(out)
|
48 |
+
return out
|
49 |
+
|
50 |
+
|
51 |
+
def create_layer_basic(in_chan, out_chan, bnum, stride=1):
|
52 |
+
layers = [BasicBlock(in_chan, out_chan, stride=stride)]
|
53 |
+
for i in range(bnum-1):
|
54 |
+
layers.append(BasicBlock(out_chan, out_chan, stride=1))
|
55 |
+
return nn.Sequential(*layers)
|
56 |
+
|
57 |
+
|
58 |
+
class Resnet18(nn.Module):
|
59 |
+
def __init__(self):
|
60 |
+
super(Resnet18, self).__init__()
|
61 |
+
self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3,
|
62 |
+
bias=False)
|
63 |
+
self.bn1 = nn.BatchNorm2d(64)
|
64 |
+
self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
|
65 |
+
self.layer1 = create_layer_basic(64, 64, bnum=2, stride=1)
|
66 |
+
self.layer2 = create_layer_basic(64, 128, bnum=2, stride=2)
|
67 |
+
self.layer3 = create_layer_basic(128, 256, bnum=2, stride=2)
|
68 |
+
self.layer4 = create_layer_basic(256, 512, bnum=2, stride=2)
|
69 |
+
self.init_weight()
|
70 |
+
|
71 |
+
def forward(self, x):
|
72 |
+
x = self.conv1(x)
|
73 |
+
x = F.relu(self.bn1(x))
|
74 |
+
x = self.maxpool(x)
|
75 |
+
|
76 |
+
x = self.layer1(x)
|
77 |
+
feat8 = self.layer2(x) # 1/8
|
78 |
+
feat16 = self.layer3(feat8) # 1/16
|
79 |
+
feat32 = self.layer4(feat16) # 1/32
|
80 |
+
return feat8, feat16, feat32
|
81 |
+
|
82 |
+
def init_weight(self):
|
83 |
+
state_dict = modelzoo.load_url(resnet18_url)
|
84 |
+
self_state_dict = self.state_dict()
|
85 |
+
for k, v in state_dict.items():
|
86 |
+
if 'fc' in k: continue
|
87 |
+
self_state_dict.update({k: v})
|
88 |
+
self.load_state_dict(self_state_dict)
|
89 |
+
|
90 |
+
def get_params(self):
|
91 |
+
wd_params, nowd_params = [], []
|
92 |
+
for name, module in self.named_modules():
|
93 |
+
if isinstance(module, (nn.Linear, nn.Conv2d)):
|
94 |
+
wd_params.append(module.weight)
|
95 |
+
if not module.bias is None:
|
96 |
+
nowd_params.append(module.bias)
|
97 |
+
elif isinstance(module, nn.BatchNorm2d):
|
98 |
+
nowd_params += list(module.parameters())
|
99 |
+
return wd_params, nowd_params
|
100 |
+
|
101 |
+
|
102 |
+
if __name__ == "__main__":
|
103 |
+
net = Resnet18()
|
104 |
+
x = torch.randn(16, 3, 224, 224)
|
105 |
+
out = net(x)
|
106 |
+
print(out[0].size())
|
107 |
+
print(out[1].size())
|
108 |
+
print(out[2].size())
|
109 |
+
net.get_params()
|
face_parsing/swap.py
ADDED
@@ -0,0 +1,133 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
import torch.nn.functional as F
|
4 |
+
import torchvision.transforms as transforms
|
5 |
+
import cv2
|
6 |
+
import numpy as np
|
7 |
+
|
8 |
+
from .model import BiSeNet
|
9 |
+
|
10 |
+
mask_regions = {
|
11 |
+
"Background":0,
|
12 |
+
"Skin":1,
|
13 |
+
"L-Eyebrow":2,
|
14 |
+
"R-Eyebrow":3,
|
15 |
+
"L-Eye":4,
|
16 |
+
"R-Eye":5,
|
17 |
+
"Eye-G":6,
|
18 |
+
"L-Ear":7,
|
19 |
+
"R-Ear":8,
|
20 |
+
"Ear-R":9,
|
21 |
+
"Nose":10,
|
22 |
+
"Mouth":11,
|
23 |
+
"U-Lip":12,
|
24 |
+
"L-Lip":13,
|
25 |
+
"Neck":14,
|
26 |
+
"Neck-L":15,
|
27 |
+
"Cloth":16,
|
28 |
+
"Hair":17,
|
29 |
+
"Hat":18
|
30 |
+
}
|
31 |
+
|
32 |
+
# Borrowed from simswap
|
33 |
+
# https://github.com/neuralchen/SimSwap/blob/26c84d2901bd56eda4d5e3c5ca6da16e65dc82a6/util/reverse2original.py#L30
|
34 |
+
class SoftErosion(nn.Module):
|
35 |
+
def __init__(self, kernel_size=15, threshold=0.6, iterations=1):
|
36 |
+
super(SoftErosion, self).__init__()
|
37 |
+
r = kernel_size // 2
|
38 |
+
self.padding = r
|
39 |
+
self.iterations = iterations
|
40 |
+
self.threshold = threshold
|
41 |
+
|
42 |
+
# Create kernel
|
43 |
+
y_indices, x_indices = torch.meshgrid(torch.arange(0., kernel_size), torch.arange(0., kernel_size))
|
44 |
+
dist = torch.sqrt((x_indices - r) ** 2 + (y_indices - r) ** 2)
|
45 |
+
kernel = dist.max() - dist
|
46 |
+
kernel /= kernel.sum()
|
47 |
+
kernel = kernel.view(1, 1, *kernel.shape)
|
48 |
+
self.register_buffer('weight', kernel)
|
49 |
+
|
50 |
+
def forward(self, x):
|
51 |
+
x = x.float()
|
52 |
+
for i in range(self.iterations - 1):
|
53 |
+
x = torch.min(x, F.conv2d(x, weight=self.weight, groups=x.shape[1], padding=self.padding))
|
54 |
+
x = F.conv2d(x, weight=self.weight, groups=x.shape[1], padding=self.padding)
|
55 |
+
|
56 |
+
mask = x >= self.threshold
|
57 |
+
x[mask] = 1.0
|
58 |
+
x[~mask] /= x[~mask].max()
|
59 |
+
|
60 |
+
return x, mask
|
61 |
+
|
62 |
+
device = "cpu"
|
63 |
+
|
64 |
+
def init_parser(pth_path, mode="cpu"):
|
65 |
+
global device
|
66 |
+
device = mode
|
67 |
+
n_classes = 19
|
68 |
+
net = BiSeNet(n_classes=n_classes)
|
69 |
+
if device == "cuda":
|
70 |
+
net.cuda()
|
71 |
+
net.load_state_dict(torch.load(pth_path))
|
72 |
+
else:
|
73 |
+
net.load_state_dict(torch.load(pth_path, map_location=torch.device('cpu')))
|
74 |
+
net.eval()
|
75 |
+
return net
|
76 |
+
|
77 |
+
|
78 |
+
def image_to_parsing(img, net):
|
79 |
+
img = cv2.resize(img, (512, 512))
|
80 |
+
img = img[:,:,::-1]
|
81 |
+
transform = transforms.Compose([
|
82 |
+
transforms.ToTensor(),
|
83 |
+
transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
|
84 |
+
])
|
85 |
+
img = transform(img.copy())
|
86 |
+
img = torch.unsqueeze(img, 0)
|
87 |
+
|
88 |
+
with torch.no_grad():
|
89 |
+
img = img.to(device)
|
90 |
+
out = net(img)[0]
|
91 |
+
parsing = out.squeeze(0).cpu().numpy().argmax(0)
|
92 |
+
return parsing
|
93 |
+
|
94 |
+
|
95 |
+
def get_mask(parsing, classes):
|
96 |
+
res = parsing == classes[0]
|
97 |
+
for val in classes[1:]:
|
98 |
+
res += parsing == val
|
99 |
+
return res
|
100 |
+
|
101 |
+
|
102 |
+
def swap_regions(source, target, net, smooth_mask, includes=[1,2,3,4,5,10,11,12,13], blur=10):
|
103 |
+
parsing = image_to_parsing(source, net)
|
104 |
+
|
105 |
+
if len(includes) == 0:
|
106 |
+
return source, np.zeros_like(source)
|
107 |
+
|
108 |
+
include_mask = get_mask(parsing, includes)
|
109 |
+
mask = np.repeat(include_mask[:, :, np.newaxis], 3, axis=2).astype("float32")
|
110 |
+
|
111 |
+
if smooth_mask is not None:
|
112 |
+
mask_tensor = torch.from_numpy(mask.copy().transpose((2, 0, 1))).float().to(device)
|
113 |
+
face_mask_tensor = mask_tensor[0] + mask_tensor[1]
|
114 |
+
soft_face_mask_tensor, _ = smooth_mask(face_mask_tensor.unsqueeze_(0).unsqueeze_(0))
|
115 |
+
soft_face_mask_tensor.squeeze_()
|
116 |
+
mask = np.repeat(soft_face_mask_tensor.cpu().numpy()[:, :, np.newaxis], 3, axis=2)
|
117 |
+
|
118 |
+
if blur > 0:
|
119 |
+
mask = cv2.GaussianBlur(mask, (0, 0), blur)
|
120 |
+
|
121 |
+
resized_source = cv2.resize((source).astype("float32"), (512, 512))
|
122 |
+
resized_target = cv2.resize((target).astype("float32"), (512, 512))
|
123 |
+
result = mask * resized_source + (1 - mask) * resized_target
|
124 |
+
result = cv2.resize(result.astype("uint8"), (source.shape[1], source.shape[0]))
|
125 |
+
|
126 |
+
return result
|
127 |
+
|
128 |
+
def mask_regions_to_list(values):
|
129 |
+
out_ids = []
|
130 |
+
for value in values:
|
131 |
+
if value in mask_regions.keys():
|
132 |
+
out_ids.append(mask_regions.get(value))
|
133 |
+
return out_ids
|
face_swapper.py
ADDED
@@ -0,0 +1,150 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import time
|
2 |
+
import torch
|
3 |
+
import onnx
|
4 |
+
import cv2
|
5 |
+
import onnxruntime
|
6 |
+
import numpy as np
|
7 |
+
from tqdm import tqdm
|
8 |
+
import torch.nn as nn
|
9 |
+
from onnx import numpy_helper
|
10 |
+
from skimage import transform as trans
|
11 |
+
import torchvision.transforms.functional as F
|
12 |
+
import torch.nn.functional as F
|
13 |
+
from utils import mask_crop, laplacian_blending
|
14 |
+
|
15 |
+
|
16 |
+
arcface_dst = np.array(
|
17 |
+
[[38.2946, 51.6963], [73.5318, 51.5014], [56.0252, 71.7366],
|
18 |
+
[41.5493, 92.3655], [70.7299, 92.2041]],
|
19 |
+
dtype=np.float32)
|
20 |
+
|
21 |
+
def estimate_norm(lmk, image_size=112, mode='arcface'):
|
22 |
+
assert lmk.shape == (5, 2)
|
23 |
+
assert image_size % 112 == 0 or image_size % 128 == 0
|
24 |
+
if image_size % 112 == 0:
|
25 |
+
ratio = float(image_size) / 112.0
|
26 |
+
diff_x = 0
|
27 |
+
else:
|
28 |
+
ratio = float(image_size) / 128.0
|
29 |
+
diff_x = 8.0 * ratio
|
30 |
+
dst = arcface_dst * ratio
|
31 |
+
dst[:, 0] += diff_x
|
32 |
+
tform = trans.SimilarityTransform()
|
33 |
+
tform.estimate(lmk, dst)
|
34 |
+
M = tform.params[0:2, :]
|
35 |
+
return M
|
36 |
+
|
37 |
+
|
38 |
+
def norm_crop2(img, landmark, image_size=112, mode='arcface'):
|
39 |
+
M = estimate_norm(landmark, image_size, mode)
|
40 |
+
warped = cv2.warpAffine(img, M, (image_size, image_size), borderValue=0.0)
|
41 |
+
return warped, M
|
42 |
+
|
43 |
+
|
44 |
+
class Inswapper():
|
45 |
+
def __init__(self, model_file=None, batch_size=32, providers=['CPUExecutionProvider']):
|
46 |
+
self.model_file = model_file
|
47 |
+
self.batch_size = batch_size
|
48 |
+
|
49 |
+
model = onnx.load(self.model_file)
|
50 |
+
graph = model.graph
|
51 |
+
self.emap = numpy_helper.to_array(graph.initializer[-1])
|
52 |
+
|
53 |
+
self.session_options = onnxruntime.SessionOptions()
|
54 |
+
self.session = onnxruntime.InferenceSession(self.model_file, sess_options=self.session_options, providers=providers)
|
55 |
+
|
56 |
+
def forward(self, imgs, latents):
|
57 |
+
preds = []
|
58 |
+
for img, latent in zip(imgs, latents):
|
59 |
+
img = img / 255
|
60 |
+
pred = self.session.run(['output'], {'target': img, 'source': latent})[0]
|
61 |
+
preds.append(pred)
|
62 |
+
|
63 |
+
def get(self, imgs, target_faces, source_faces):
|
64 |
+
imgs = list(imgs)
|
65 |
+
|
66 |
+
preds = [None] * len(imgs)
|
67 |
+
matrs = [None] * len(imgs)
|
68 |
+
|
69 |
+
for idx, (img, target_face, source_face) in enumerate(zip(imgs, target_faces, source_faces)):
|
70 |
+
matrix, blob, latent = self.prepare_data(img, target_face, source_face)
|
71 |
+
pred = self.session.run(['output'], {'target': blob, 'source': latent})[0]
|
72 |
+
pred = pred.transpose((0, 2, 3, 1))[0]
|
73 |
+
pred = np.clip(255 * pred, 0, 255).astype(np.uint8)[:, :, ::-1]
|
74 |
+
|
75 |
+
preds[idx] = pred
|
76 |
+
matrs[idx] = matrix
|
77 |
+
|
78 |
+
return (preds, matrs)
|
79 |
+
|
80 |
+
def prepare_data(self, img, target_face, source_face):
|
81 |
+
if isinstance(img, str):
|
82 |
+
img = cv2.imread(img)
|
83 |
+
|
84 |
+
aligned_img, matrix = norm_crop2(img, target_face.kps, 128)
|
85 |
+
|
86 |
+
blob = cv2.dnn.blobFromImage(aligned_img, 1.0 / 255, (128, 128), (0., 0., 0.), swapRB=True)
|
87 |
+
|
88 |
+
latent = source_face.normed_embedding.reshape((1, -1))
|
89 |
+
latent = np.dot(latent, self.emap)
|
90 |
+
latent /= np.linalg.norm(latent)
|
91 |
+
|
92 |
+
return (matrix, blob, latent)
|
93 |
+
|
94 |
+
def batch_forward(self, img_list, target_f_list, source_f_list):
|
95 |
+
num_samples = len(img_list)
|
96 |
+
num_batches = (num_samples + self.batch_size - 1) // self.batch_size
|
97 |
+
|
98 |
+
for i in tqdm(range(num_batches), desc="Generating face"):
|
99 |
+
start_idx = i * self.batch_size
|
100 |
+
end_idx = min((i + 1) * self.batch_size, num_samples)
|
101 |
+
|
102 |
+
batch_img = img_list[start_idx:end_idx]
|
103 |
+
batch_target_f = target_f_list[start_idx:end_idx]
|
104 |
+
batch_source_f = source_f_list[start_idx:end_idx]
|
105 |
+
|
106 |
+
batch_pred, batch_matr = self.get(batch_img, batch_target_f, batch_source_f)
|
107 |
+
|
108 |
+
yield batch_pred, batch_matr
|
109 |
+
|
110 |
+
|
111 |
+
def paste_to_whole(foreground, background, matrix, mask=None, crop_mask=(0,0,0,0), blur_amount=0.1, erode_amount = 0.15, blend_method='linear'):
|
112 |
+
inv_matrix = cv2.invertAffineTransform(matrix)
|
113 |
+
fg_shape = foreground.shape[:2]
|
114 |
+
bg_shape = (background.shape[1], background.shape[0])
|
115 |
+
foreground = cv2.warpAffine(foreground, inv_matrix, bg_shape, borderValue=0.0)
|
116 |
+
|
117 |
+
if mask is None:
|
118 |
+
mask = np.full(fg_shape, 1., dtype=np.float32)
|
119 |
+
mask = mask_crop(mask, crop_mask)
|
120 |
+
mask = cv2.warpAffine(mask, inv_matrix, bg_shape, borderValue=0.0)
|
121 |
+
else:
|
122 |
+
assert fg_shape == mask.shape[:2], "foreground & mask shape mismatch!"
|
123 |
+
mask = mask_crop(mask, crop_mask).astype('float32')
|
124 |
+
mask = cv2.warpAffine(mask, inv_matrix, (background.shape[1], background.shape[0]), borderValue=0.0)
|
125 |
+
|
126 |
+
_mask = mask.copy()
|
127 |
+
_mask[_mask > 0.05] = 1.
|
128 |
+
non_zero_points = cv2.findNonZero(_mask)
|
129 |
+
_, _, w, h = cv2.boundingRect(non_zero_points)
|
130 |
+
mask_size = int(np.sqrt(w * h))
|
131 |
+
|
132 |
+
if erode_amount > 0:
|
133 |
+
kernel_size = max(int(mask_size * erode_amount), 1)
|
134 |
+
structuring_element = cv2.getStructuringElement(cv2.MORPH_RECT, (kernel_size, kernel_size))
|
135 |
+
mask = cv2.erode(mask, structuring_element)
|
136 |
+
|
137 |
+
if blur_amount > 0:
|
138 |
+
kernel_size = max(int(mask_size * blur_amount), 3)
|
139 |
+
if kernel_size % 2 == 0:
|
140 |
+
kernel_size += 1
|
141 |
+
mask = cv2.GaussianBlur(mask, (kernel_size, kernel_size), 0)
|
142 |
+
|
143 |
+
mask = np.tile(np.expand_dims(mask, axis=-1), (1, 1, 3))
|
144 |
+
|
145 |
+
if blend_method == 'laplacian':
|
146 |
+
composite_image = laplacian_blending(foreground, background, mask.clip(0,1), num_levels=4)
|
147 |
+
else:
|
148 |
+
composite_image = mask * foreground + (1 - mask) * background
|
149 |
+
|
150 |
+
return composite_image.astype("uint8").clip(0, 255)
|
gfpgan/weights/detection_Resnet50_Final.pth
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:6d1de9c2944f2ccddca5f5e010ea5ae64a39845a86311af6fdf30841b0a5a16d
|
3 |
+
size 109497761
|
gfpgan/weights/parsing_parsenet.pth
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:3d558d8d0e42c20224f13cf5a29c79eba2d59913419f945545d8cf7b72920de2
|
3 |
+
size 85331193
|
nsfw_checker/LICENSE.md
ADDED
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
Copyright 2016, Yahoo Inc.
|
3 |
+
|
4 |
+
Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met:
|
5 |
+
|
6 |
+
1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer.
|
7 |
+
|
8 |
+
2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution.
|
9 |
+
|
10 |
+
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
11 |
+
|
nsfw_checker/__init__.py
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
from . opennsfw import NSFWChecker
|
nsfw_checker/opennsfw.py
ADDED
@@ -0,0 +1,37 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import cv2
|
2 |
+
import torch
|
3 |
+
import onnx
|
4 |
+
import onnxruntime
|
5 |
+
import numpy as np
|
6 |
+
from tqdm import tqdm
|
7 |
+
|
8 |
+
# https://github.com/yahoo/open_nsfw
|
9 |
+
|
10 |
+
class NSFWChecker:
|
11 |
+
def __init__(self, model_path=None, providers=["CPUExecutionProvider"]):
|
12 |
+
model = onnx.load(model_path)
|
13 |
+
self.input_name = model.graph.input[0].name
|
14 |
+
session_options = onnxruntime.SessionOptions()
|
15 |
+
self.session = onnxruntime.InferenceSession(model_path, sess_options=session_options, providers=providers)
|
16 |
+
|
17 |
+
def is_nsfw(self, img_paths, threshold = 0.85):
|
18 |
+
skip_step = 1
|
19 |
+
total_len = len(img_paths)
|
20 |
+
if total_len < 100: skip_step = 1
|
21 |
+
if total_len > 100 and total_len < 500: skip_step = 10
|
22 |
+
if total_len > 500 and total_len < 1000: skip_step = 20
|
23 |
+
if total_len > 1000 and total_len < 10000: skip_step = 50
|
24 |
+
if total_len > 10000: skip_step = 100
|
25 |
+
|
26 |
+
for idx in tqdm(range(0, total_len, skip_step), total=int(total_len // skip_step), desc="Checking for NSFW contents"):
|
27 |
+
img = cv2.imread(img_paths[idx])
|
28 |
+
img = cv2.resize(img, (224,224)).astype('float32')
|
29 |
+
img -= np.array([104, 117, 123], dtype=np.float32)
|
30 |
+
img = np.expand_dims(img, axis=0)
|
31 |
+
|
32 |
+
score = self.session.run(None, {self.input_name:img})[0][0][1]
|
33 |
+
|
34 |
+
if score > threshold:
|
35 |
+
print(f"Detected nsfw score:{score}")
|
36 |
+
return True
|
37 |
+
return False
|
nsfw_detector.py
ADDED
@@ -0,0 +1,65 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from torchvision.transforms import Normalize
|
2 |
+
import torchvision.transforms as T
|
3 |
+
import torch.nn as nn
|
4 |
+
from PIL import Image
|
5 |
+
import numpy as np
|
6 |
+
import torch
|
7 |
+
import timm
|
8 |
+
from tqdm import tqdm
|
9 |
+
|
10 |
+
# https://github.com/Whiax/NSFW-Classifier/raw/main/nsfwmodel_281.pth
|
11 |
+
normalize_t = Normalize((0.4814, 0.4578, 0.4082), (0.2686, 0.2613, 0.2757))
|
12 |
+
|
13 |
+
#nsfw classifier
|
14 |
+
class NSFWClassifier(nn.Module):
|
15 |
+
def __init__(self):
|
16 |
+
super().__init__()
|
17 |
+
nsfw_model=self
|
18 |
+
nsfw_model.root_model = timm.create_model('convnext_base_in22ft1k', pretrained=True)
|
19 |
+
nsfw_model.linear_probe = nn.Linear(1024, 1, bias=False)
|
20 |
+
|
21 |
+
def forward(self, x):
|
22 |
+
nsfw_model = self
|
23 |
+
x = normalize_t(x)
|
24 |
+
x = nsfw_model.root_model.stem(x)
|
25 |
+
x = nsfw_model.root_model.stages(x)
|
26 |
+
x = nsfw_model.root_model.head.global_pool(x)
|
27 |
+
x = nsfw_model.root_model.head.norm(x)
|
28 |
+
x = nsfw_model.root_model.head.flatten(x)
|
29 |
+
x = nsfw_model.linear_probe(x)
|
30 |
+
return x
|
31 |
+
|
32 |
+
def is_nsfw(self, img_paths, threshold = 0.98):
|
33 |
+
skip_step = 1
|
34 |
+
total_len = len(img_paths)
|
35 |
+
if total_len < 100: skip_step = 1
|
36 |
+
if total_len > 100 and total_len < 500: skip_step = 10
|
37 |
+
if total_len > 500 and total_len < 1000: skip_step = 20
|
38 |
+
if total_len > 1000 and total_len < 10000: skip_step = 50
|
39 |
+
if total_len > 10000: skip_step = 100
|
40 |
+
|
41 |
+
for idx in tqdm(range(0, total_len, skip_step), total=int(total_len // skip_step), desc="Checking for NSFW contents"):
|
42 |
+
_img = Image.open(img_paths[idx]).convert('RGB')
|
43 |
+
img = _img.resize((224, 224))
|
44 |
+
img = np.array(img)/255
|
45 |
+
img = T.ToTensor()(img).unsqueeze(0).float()
|
46 |
+
if next(self.parameters()).is_cuda:
|
47 |
+
img = img.cuda()
|
48 |
+
with torch.no_grad():
|
49 |
+
score = self.forward(img).sigmoid()[0].item()
|
50 |
+
if score > threshold:
|
51 |
+
print(f"Detected nsfw score:{score}")
|
52 |
+
_img.save("nsfw.jpg")
|
53 |
+
return True
|
54 |
+
return False
|
55 |
+
|
56 |
+
def get_nsfw_detector(model_path='nsfwmodel_281.pth', device="cpu"):
|
57 |
+
#load base model
|
58 |
+
nsfw_model = NSFWClassifier()
|
59 |
+
nsfw_model = nsfw_model.eval()
|
60 |
+
#load linear weights
|
61 |
+
linear_pth = model_path
|
62 |
+
linear_state_dict = torch.load(linear_pth, map_location='cpu')
|
63 |
+
nsfw_model.linear_probe.load_state_dict(linear_state_dict)
|
64 |
+
nsfw_model = nsfw_model.to(device)
|
65 |
+
return nsfw_model
|
requirements.txt
ADDED
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
torch==2.0.1
|
2 |
+
torchvision
|
3 |
+
gradio>=3.33.1
|
4 |
+
insightface==0.7.3
|
5 |
+
moviepy>=1.0.3
|
6 |
+
numpy
|
7 |
+
onnx==1.14.0
|
8 |
+
onnxruntime==1.15.0
|
9 |
+
opencv-python>=4.7.0.72
|
10 |
+
opencv-python-headless>=4.7.0.72
|
11 |
+
gfpgan==1.3.8
|
12 |
+
|
upscaler/RealESRGAN/__init__.py
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
from .model import RealESRGAN
|
upscaler/RealESRGAN/arch_utils.py
ADDED
@@ -0,0 +1,197 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import math
|
2 |
+
import torch
|
3 |
+
from torch import nn as nn
|
4 |
+
from torch.nn import functional as F
|
5 |
+
from torch.nn import init as init
|
6 |
+
from torch.nn.modules.batchnorm import _BatchNorm
|
7 |
+
|
8 |
+
@torch.no_grad()
|
9 |
+
def default_init_weights(module_list, scale=1, bias_fill=0, **kwargs):
|
10 |
+
"""Initialize network weights.
|
11 |
+
|
12 |
+
Args:
|
13 |
+
module_list (list[nn.Module] | nn.Module): Modules to be initialized.
|
14 |
+
scale (float): Scale initialized weights, especially for residual
|
15 |
+
blocks. Default: 1.
|
16 |
+
bias_fill (float): The value to fill bias. Default: 0
|
17 |
+
kwargs (dict): Other arguments for initialization function.
|
18 |
+
"""
|
19 |
+
if not isinstance(module_list, list):
|
20 |
+
module_list = [module_list]
|
21 |
+
for module in module_list:
|
22 |
+
for m in module.modules():
|
23 |
+
if isinstance(m, nn.Conv2d):
|
24 |
+
init.kaiming_normal_(m.weight, **kwargs)
|
25 |
+
m.weight.data *= scale
|
26 |
+
if m.bias is not None:
|
27 |
+
m.bias.data.fill_(bias_fill)
|
28 |
+
elif isinstance(m, nn.Linear):
|
29 |
+
init.kaiming_normal_(m.weight, **kwargs)
|
30 |
+
m.weight.data *= scale
|
31 |
+
if m.bias is not None:
|
32 |
+
m.bias.data.fill_(bias_fill)
|
33 |
+
elif isinstance(m, _BatchNorm):
|
34 |
+
init.constant_(m.weight, 1)
|
35 |
+
if m.bias is not None:
|
36 |
+
m.bias.data.fill_(bias_fill)
|
37 |
+
|
38 |
+
|
39 |
+
def make_layer(basic_block, num_basic_block, **kwarg):
|
40 |
+
"""Make layers by stacking the same blocks.
|
41 |
+
|
42 |
+
Args:
|
43 |
+
basic_block (nn.module): nn.module class for basic block.
|
44 |
+
num_basic_block (int): number of blocks.
|
45 |
+
|
46 |
+
Returns:
|
47 |
+
nn.Sequential: Stacked blocks in nn.Sequential.
|
48 |
+
"""
|
49 |
+
layers = []
|
50 |
+
for _ in range(num_basic_block):
|
51 |
+
layers.append(basic_block(**kwarg))
|
52 |
+
return nn.Sequential(*layers)
|
53 |
+
|
54 |
+
|
55 |
+
class ResidualBlockNoBN(nn.Module):
|
56 |
+
"""Residual block without BN.
|
57 |
+
|
58 |
+
It has a style of:
|
59 |
+
---Conv-ReLU-Conv-+-
|
60 |
+
|________________|
|
61 |
+
|
62 |
+
Args:
|
63 |
+
num_feat (int): Channel number of intermediate features.
|
64 |
+
Default: 64.
|
65 |
+
res_scale (float): Residual scale. Default: 1.
|
66 |
+
pytorch_init (bool): If set to True, use pytorch default init,
|
67 |
+
otherwise, use default_init_weights. Default: False.
|
68 |
+
"""
|
69 |
+
|
70 |
+
def __init__(self, num_feat=64, res_scale=1, pytorch_init=False):
|
71 |
+
super(ResidualBlockNoBN, self).__init__()
|
72 |
+
self.res_scale = res_scale
|
73 |
+
self.conv1 = nn.Conv2d(num_feat, num_feat, 3, 1, 1, bias=True)
|
74 |
+
self.conv2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1, bias=True)
|
75 |
+
self.relu = nn.ReLU(inplace=True)
|
76 |
+
|
77 |
+
if not pytorch_init:
|
78 |
+
default_init_weights([self.conv1, self.conv2], 0.1)
|
79 |
+
|
80 |
+
def forward(self, x):
|
81 |
+
identity = x
|
82 |
+
out = self.conv2(self.relu(self.conv1(x)))
|
83 |
+
return identity + out * self.res_scale
|
84 |
+
|
85 |
+
|
86 |
+
class Upsample(nn.Sequential):
|
87 |
+
"""Upsample module.
|
88 |
+
|
89 |
+
Args:
|
90 |
+
scale (int): Scale factor. Supported scales: 2^n and 3.
|
91 |
+
num_feat (int): Channel number of intermediate features.
|
92 |
+
"""
|
93 |
+
|
94 |
+
def __init__(self, scale, num_feat):
|
95 |
+
m = []
|
96 |
+
if (scale & (scale - 1)) == 0: # scale = 2^n
|
97 |
+
for _ in range(int(math.log(scale, 2))):
|
98 |
+
m.append(nn.Conv2d(num_feat, 4 * num_feat, 3, 1, 1))
|
99 |
+
m.append(nn.PixelShuffle(2))
|
100 |
+
elif scale == 3:
|
101 |
+
m.append(nn.Conv2d(num_feat, 9 * num_feat, 3, 1, 1))
|
102 |
+
m.append(nn.PixelShuffle(3))
|
103 |
+
else:
|
104 |
+
raise ValueError(f'scale {scale} is not supported. ' 'Supported scales: 2^n and 3.')
|
105 |
+
super(Upsample, self).__init__(*m)
|
106 |
+
|
107 |
+
|
108 |
+
def flow_warp(x, flow, interp_mode='bilinear', padding_mode='zeros', align_corners=True):
|
109 |
+
"""Warp an image or feature map with optical flow.
|
110 |
+
|
111 |
+
Args:
|
112 |
+
x (Tensor): Tensor with size (n, c, h, w).
|
113 |
+
flow (Tensor): Tensor with size (n, h, w, 2), normal value.
|
114 |
+
interp_mode (str): 'nearest' or 'bilinear'. Default: 'bilinear'.
|
115 |
+
padding_mode (str): 'zeros' or 'border' or 'reflection'.
|
116 |
+
Default: 'zeros'.
|
117 |
+
align_corners (bool): Before pytorch 1.3, the default value is
|
118 |
+
align_corners=True. After pytorch 1.3, the default value is
|
119 |
+
align_corners=False. Here, we use the True as default.
|
120 |
+
|
121 |
+
Returns:
|
122 |
+
Tensor: Warped image or feature map.
|
123 |
+
"""
|
124 |
+
assert x.size()[-2:] == flow.size()[1:3]
|
125 |
+
_, _, h, w = x.size()
|
126 |
+
# create mesh grid
|
127 |
+
grid_y, grid_x = torch.meshgrid(torch.arange(0, h).type_as(x), torch.arange(0, w).type_as(x))
|
128 |
+
grid = torch.stack((grid_x, grid_y), 2).float() # W(x), H(y), 2
|
129 |
+
grid.requires_grad = False
|
130 |
+
|
131 |
+
vgrid = grid + flow
|
132 |
+
# scale grid to [-1,1]
|
133 |
+
vgrid_x = 2.0 * vgrid[:, :, :, 0] / max(w - 1, 1) - 1.0
|
134 |
+
vgrid_y = 2.0 * vgrid[:, :, :, 1] / max(h - 1, 1) - 1.0
|
135 |
+
vgrid_scaled = torch.stack((vgrid_x, vgrid_y), dim=3)
|
136 |
+
output = F.grid_sample(x, vgrid_scaled, mode=interp_mode, padding_mode=padding_mode, align_corners=align_corners)
|
137 |
+
|
138 |
+
# TODO, what if align_corners=False
|
139 |
+
return output
|
140 |
+
|
141 |
+
|
142 |
+
def resize_flow(flow, size_type, sizes, interp_mode='bilinear', align_corners=False):
|
143 |
+
"""Resize a flow according to ratio or shape.
|
144 |
+
|
145 |
+
Args:
|
146 |
+
flow (Tensor): Precomputed flow. shape [N, 2, H, W].
|
147 |
+
size_type (str): 'ratio' or 'shape'.
|
148 |
+
sizes (list[int | float]): the ratio for resizing or the final output
|
149 |
+
shape.
|
150 |
+
1) The order of ratio should be [ratio_h, ratio_w]. For
|
151 |
+
downsampling, the ratio should be smaller than 1.0 (i.e., ratio
|
152 |
+
< 1.0). For upsampling, the ratio should be larger than 1.0 (i.e.,
|
153 |
+
ratio > 1.0).
|
154 |
+
2) The order of output_size should be [out_h, out_w].
|
155 |
+
interp_mode (str): The mode of interpolation for resizing.
|
156 |
+
Default: 'bilinear'.
|
157 |
+
align_corners (bool): Whether align corners. Default: False.
|
158 |
+
|
159 |
+
Returns:
|
160 |
+
Tensor: Resized flow.
|
161 |
+
"""
|
162 |
+
_, _, flow_h, flow_w = flow.size()
|
163 |
+
if size_type == 'ratio':
|
164 |
+
output_h, output_w = int(flow_h * sizes[0]), int(flow_w * sizes[1])
|
165 |
+
elif size_type == 'shape':
|
166 |
+
output_h, output_w = sizes[0], sizes[1]
|
167 |
+
else:
|
168 |
+
raise ValueError(f'Size type should be ratio or shape, but got type {size_type}.')
|
169 |
+
|
170 |
+
input_flow = flow.clone()
|
171 |
+
ratio_h = output_h / flow_h
|
172 |
+
ratio_w = output_w / flow_w
|
173 |
+
input_flow[:, 0, :, :] *= ratio_w
|
174 |
+
input_flow[:, 1, :, :] *= ratio_h
|
175 |
+
resized_flow = F.interpolate(
|
176 |
+
input=input_flow, size=(output_h, output_w), mode=interp_mode, align_corners=align_corners)
|
177 |
+
return resized_flow
|
178 |
+
|
179 |
+
|
180 |
+
# TODO: may write a cpp file
|
181 |
+
def pixel_unshuffle(x, scale):
|
182 |
+
""" Pixel unshuffle.
|
183 |
+
|
184 |
+
Args:
|
185 |
+
x (Tensor): Input feature with shape (b, c, hh, hw).
|
186 |
+
scale (int): Downsample ratio.
|
187 |
+
|
188 |
+
Returns:
|
189 |
+
Tensor: the pixel unshuffled feature.
|
190 |
+
"""
|
191 |
+
b, c, hh, hw = x.size()
|
192 |
+
out_channel = c * (scale**2)
|
193 |
+
assert hh % scale == 0 and hw % scale == 0
|
194 |
+
h = hh // scale
|
195 |
+
w = hw // scale
|
196 |
+
x_view = x.view(b, c, h, scale, w, scale)
|
197 |
+
return x_view.permute(0, 1, 3, 5, 2, 4).reshape(b, out_channel, h, w)
|
upscaler/RealESRGAN/model.py
ADDED
@@ -0,0 +1,90 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import torch
|
3 |
+
from torch.nn import functional as F
|
4 |
+
from PIL import Image
|
5 |
+
import numpy as np
|
6 |
+
import cv2
|
7 |
+
|
8 |
+
from .rrdbnet_arch import RRDBNet
|
9 |
+
from .utils import pad_reflect, split_image_into_overlapping_patches, stich_together, \
|
10 |
+
unpad_image
|
11 |
+
|
12 |
+
|
13 |
+
HF_MODELS = {
|
14 |
+
2: dict(
|
15 |
+
repo_id='sberbank-ai/Real-ESRGAN',
|
16 |
+
filename='RealESRGAN_x2.pth',
|
17 |
+
),
|
18 |
+
4: dict(
|
19 |
+
repo_id='sberbank-ai/Real-ESRGAN',
|
20 |
+
filename='RealESRGAN_x4.pth',
|
21 |
+
),
|
22 |
+
8: dict(
|
23 |
+
repo_id='sberbank-ai/Real-ESRGAN',
|
24 |
+
filename='RealESRGAN_x8.pth',
|
25 |
+
),
|
26 |
+
}
|
27 |
+
|
28 |
+
|
29 |
+
class RealESRGAN:
|
30 |
+
def __init__(self, device, scale=4):
|
31 |
+
self.device = device
|
32 |
+
self.scale = scale
|
33 |
+
self.model = RRDBNet(
|
34 |
+
num_in_ch=3, num_out_ch=3, num_feat=64,
|
35 |
+
num_block=23, num_grow_ch=32, scale=scale
|
36 |
+
)
|
37 |
+
|
38 |
+
def load_weights(self, model_path, download=True):
|
39 |
+
if not os.path.exists(model_path) and download:
|
40 |
+
from huggingface_hub import hf_hub_url, cached_download
|
41 |
+
assert self.scale in [2,4,8], 'You can download models only with scales: 2, 4, 8'
|
42 |
+
config = HF_MODELS[self.scale]
|
43 |
+
cache_dir = os.path.dirname(model_path)
|
44 |
+
local_filename = os.path.basename(model_path)
|
45 |
+
config_file_url = hf_hub_url(repo_id=config['repo_id'], filename=config['filename'])
|
46 |
+
cached_download(config_file_url, cache_dir=cache_dir, force_filename=local_filename)
|
47 |
+
print('Weights downloaded to:', os.path.join(cache_dir, local_filename))
|
48 |
+
|
49 |
+
loadnet = torch.load(model_path)
|
50 |
+
if 'params' in loadnet:
|
51 |
+
self.model.load_state_dict(loadnet['params'], strict=True)
|
52 |
+
elif 'params_ema' in loadnet:
|
53 |
+
self.model.load_state_dict(loadnet['params_ema'], strict=True)
|
54 |
+
else:
|
55 |
+
self.model.load_state_dict(loadnet, strict=True)
|
56 |
+
self.model.eval()
|
57 |
+
self.model.to(self.device)
|
58 |
+
|
59 |
+
@torch.cuda.amp.autocast()
|
60 |
+
def predict(self, lr_image, batch_size=4, patches_size=192,
|
61 |
+
padding=24, pad_size=15):
|
62 |
+
scale = self.scale
|
63 |
+
device = self.device
|
64 |
+
lr_image = np.array(lr_image)
|
65 |
+
lr_image = pad_reflect(lr_image, pad_size)
|
66 |
+
|
67 |
+
patches, p_shape = split_image_into_overlapping_patches(
|
68 |
+
lr_image, patch_size=patches_size, padding_size=padding
|
69 |
+
)
|
70 |
+
img = torch.FloatTensor(patches/255).permute((0,3,1,2)).to(device).detach()
|
71 |
+
|
72 |
+
with torch.no_grad():
|
73 |
+
res = self.model(img[0:batch_size])
|
74 |
+
for i in range(batch_size, img.shape[0], batch_size):
|
75 |
+
res = torch.cat((res, self.model(img[i:i+batch_size])), 0)
|
76 |
+
|
77 |
+
sr_image = res.permute((0,2,3,1)).clamp_(0, 1).cpu()
|
78 |
+
np_sr_image = sr_image.numpy()
|
79 |
+
|
80 |
+
padded_size_scaled = tuple(np.multiply(p_shape[0:2], scale)) + (3,)
|
81 |
+
scaled_image_shape = tuple(np.multiply(lr_image.shape[0:2], scale)) + (3,)
|
82 |
+
np_sr_image = stich_together(
|
83 |
+
np_sr_image, padded_image_shape=padded_size_scaled,
|
84 |
+
target_shape=scaled_image_shape, padding_size=padding * scale
|
85 |
+
)
|
86 |
+
sr_img = (np_sr_image*255).astype(np.uint8)
|
87 |
+
sr_img = unpad_image(sr_img, pad_size*scale)
|
88 |
+
#sr_img = Image.fromarray(sr_img)
|
89 |
+
|
90 |
+
return sr_img
|
upscaler/RealESRGAN/rrdbnet_arch.py
ADDED
@@ -0,0 +1,121 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
from torch import nn as nn
|
3 |
+
from torch.nn import functional as F
|
4 |
+
|
5 |
+
from .arch_utils import default_init_weights, make_layer, pixel_unshuffle
|
6 |
+
|
7 |
+
|
8 |
+
class ResidualDenseBlock(nn.Module):
|
9 |
+
"""Residual Dense Block.
|
10 |
+
|
11 |
+
Used in RRDB block in ESRGAN.
|
12 |
+
|
13 |
+
Args:
|
14 |
+
num_feat (int): Channel number of intermediate features.
|
15 |
+
num_grow_ch (int): Channels for each growth.
|
16 |
+
"""
|
17 |
+
|
18 |
+
def __init__(self, num_feat=64, num_grow_ch=32):
|
19 |
+
super(ResidualDenseBlock, self).__init__()
|
20 |
+
self.conv1 = nn.Conv2d(num_feat, num_grow_ch, 3, 1, 1)
|
21 |
+
self.conv2 = nn.Conv2d(num_feat + num_grow_ch, num_grow_ch, 3, 1, 1)
|
22 |
+
self.conv3 = nn.Conv2d(num_feat + 2 * num_grow_ch, num_grow_ch, 3, 1, 1)
|
23 |
+
self.conv4 = nn.Conv2d(num_feat + 3 * num_grow_ch, num_grow_ch, 3, 1, 1)
|
24 |
+
self.conv5 = nn.Conv2d(num_feat + 4 * num_grow_ch, num_feat, 3, 1, 1)
|
25 |
+
|
26 |
+
self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
|
27 |
+
|
28 |
+
# initialization
|
29 |
+
default_init_weights([self.conv1, self.conv2, self.conv3, self.conv4, self.conv5], 0.1)
|
30 |
+
|
31 |
+
def forward(self, x):
|
32 |
+
x1 = self.lrelu(self.conv1(x))
|
33 |
+
x2 = self.lrelu(self.conv2(torch.cat((x, x1), 1)))
|
34 |
+
x3 = self.lrelu(self.conv3(torch.cat((x, x1, x2), 1)))
|
35 |
+
x4 = self.lrelu(self.conv4(torch.cat((x, x1, x2, x3), 1)))
|
36 |
+
x5 = self.conv5(torch.cat((x, x1, x2, x3, x4), 1))
|
37 |
+
# Emperically, we use 0.2 to scale the residual for better performance
|
38 |
+
return x5 * 0.2 + x
|
39 |
+
|
40 |
+
|
41 |
+
class RRDB(nn.Module):
|
42 |
+
"""Residual in Residual Dense Block.
|
43 |
+
|
44 |
+
Used in RRDB-Net in ESRGAN.
|
45 |
+
|
46 |
+
Args:
|
47 |
+
num_feat (int): Channel number of intermediate features.
|
48 |
+
num_grow_ch (int): Channels for each growth.
|
49 |
+
"""
|
50 |
+
|
51 |
+
def __init__(self, num_feat, num_grow_ch=32):
|
52 |
+
super(RRDB, self).__init__()
|
53 |
+
self.rdb1 = ResidualDenseBlock(num_feat, num_grow_ch)
|
54 |
+
self.rdb2 = ResidualDenseBlock(num_feat, num_grow_ch)
|
55 |
+
self.rdb3 = ResidualDenseBlock(num_feat, num_grow_ch)
|
56 |
+
|
57 |
+
def forward(self, x):
|
58 |
+
out = self.rdb1(x)
|
59 |
+
out = self.rdb2(out)
|
60 |
+
out = self.rdb3(out)
|
61 |
+
# Emperically, we use 0.2 to scale the residual for better performance
|
62 |
+
return out * 0.2 + x
|
63 |
+
|
64 |
+
|
65 |
+
class RRDBNet(nn.Module):
|
66 |
+
"""Networks consisting of Residual in Residual Dense Block, which is used
|
67 |
+
in ESRGAN.
|
68 |
+
|
69 |
+
ESRGAN: Enhanced Super-Resolution Generative Adversarial Networks.
|
70 |
+
|
71 |
+
We extend ESRGAN for scale x2 and scale x1.
|
72 |
+
Note: This is one option for scale 1, scale 2 in RRDBNet.
|
73 |
+
We first employ the pixel-unshuffle (an inverse operation of pixelshuffle to reduce the spatial size
|
74 |
+
and enlarge the channel size before feeding inputs into the main ESRGAN architecture.
|
75 |
+
|
76 |
+
Args:
|
77 |
+
num_in_ch (int): Channel number of inputs.
|
78 |
+
num_out_ch (int): Channel number of outputs.
|
79 |
+
num_feat (int): Channel number of intermediate features.
|
80 |
+
Default: 64
|
81 |
+
num_block (int): Block number in the trunk network. Defaults: 23
|
82 |
+
num_grow_ch (int): Channels for each growth. Default: 32.
|
83 |
+
"""
|
84 |
+
|
85 |
+
def __init__(self, num_in_ch, num_out_ch, scale=4, num_feat=64, num_block=23, num_grow_ch=32):
|
86 |
+
super(RRDBNet, self).__init__()
|
87 |
+
self.scale = scale
|
88 |
+
if scale == 2:
|
89 |
+
num_in_ch = num_in_ch * 4
|
90 |
+
elif scale == 1:
|
91 |
+
num_in_ch = num_in_ch * 16
|
92 |
+
self.conv_first = nn.Conv2d(num_in_ch, num_feat, 3, 1, 1)
|
93 |
+
self.body = make_layer(RRDB, num_block, num_feat=num_feat, num_grow_ch=num_grow_ch)
|
94 |
+
self.conv_body = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
|
95 |
+
# upsample
|
96 |
+
self.conv_up1 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
|
97 |
+
self.conv_up2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
|
98 |
+
if scale == 8:
|
99 |
+
self.conv_up3 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
|
100 |
+
self.conv_hr = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
|
101 |
+
self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)
|
102 |
+
|
103 |
+
self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
|
104 |
+
|
105 |
+
def forward(self, x):
|
106 |
+
if self.scale == 2:
|
107 |
+
feat = pixel_unshuffle(x, scale=2)
|
108 |
+
elif self.scale == 1:
|
109 |
+
feat = pixel_unshuffle(x, scale=4)
|
110 |
+
else:
|
111 |
+
feat = x
|
112 |
+
feat = self.conv_first(feat)
|
113 |
+
body_feat = self.conv_body(self.body(feat))
|
114 |
+
feat = feat + body_feat
|
115 |
+
# upsample
|
116 |
+
feat = self.lrelu(self.conv_up1(F.interpolate(feat, scale_factor=2, mode='nearest')))
|
117 |
+
feat = self.lrelu(self.conv_up2(F.interpolate(feat, scale_factor=2, mode='nearest')))
|
118 |
+
if self.scale == 8:
|
119 |
+
feat = self.lrelu(self.conv_up3(F.interpolate(feat, scale_factor=2, mode='nearest')))
|
120 |
+
out = self.conv_last(self.lrelu(self.conv_hr(feat)))
|
121 |
+
return out
|
upscaler/RealESRGAN/utils.py
ADDED
@@ -0,0 +1,133 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
import torch
|
3 |
+
from PIL import Image
|
4 |
+
import os
|
5 |
+
import io
|
6 |
+
|
7 |
+
def pad_reflect(image, pad_size):
|
8 |
+
imsize = image.shape
|
9 |
+
height, width = imsize[:2]
|
10 |
+
new_img = np.zeros([height+pad_size*2, width+pad_size*2, imsize[2]]).astype(np.uint8)
|
11 |
+
new_img[pad_size:-pad_size, pad_size:-pad_size, :] = image
|
12 |
+
|
13 |
+
new_img[0:pad_size, pad_size:-pad_size, :] = np.flip(image[0:pad_size, :, :], axis=0) #top
|
14 |
+
new_img[-pad_size:, pad_size:-pad_size, :] = np.flip(image[-pad_size:, :, :], axis=0) #bottom
|
15 |
+
new_img[:, 0:pad_size, :] = np.flip(new_img[:, pad_size:pad_size*2, :], axis=1) #left
|
16 |
+
new_img[:, -pad_size:, :] = np.flip(new_img[:, -pad_size*2:-pad_size, :], axis=1) #right
|
17 |
+
|
18 |
+
return new_img
|
19 |
+
|
20 |
+
def unpad_image(image, pad_size):
|
21 |
+
return image[pad_size:-pad_size, pad_size:-pad_size, :]
|
22 |
+
|
23 |
+
|
24 |
+
def process_array(image_array, expand=True):
|
25 |
+
""" Process a 3-dimensional array into a scaled, 4 dimensional batch of size 1. """
|
26 |
+
|
27 |
+
image_batch = image_array / 255.0
|
28 |
+
if expand:
|
29 |
+
image_batch = np.expand_dims(image_batch, axis=0)
|
30 |
+
return image_batch
|
31 |
+
|
32 |
+
|
33 |
+
def process_output(output_tensor):
|
34 |
+
""" Transforms the 4-dimensional output tensor into a suitable image format. """
|
35 |
+
|
36 |
+
sr_img = output_tensor.clip(0, 1) * 255
|
37 |
+
sr_img = np.uint8(sr_img)
|
38 |
+
return sr_img
|
39 |
+
|
40 |
+
|
41 |
+
def pad_patch(image_patch, padding_size, channel_last=True):
|
42 |
+
""" Pads image_patch with with padding_size edge values. """
|
43 |
+
|
44 |
+
if channel_last:
|
45 |
+
return np.pad(
|
46 |
+
image_patch,
|
47 |
+
((padding_size, padding_size), (padding_size, padding_size), (0, 0)),
|
48 |
+
'edge',
|
49 |
+
)
|
50 |
+
else:
|
51 |
+
return np.pad(
|
52 |
+
image_patch,
|
53 |
+
((0, 0), (padding_size, padding_size), (padding_size, padding_size)),
|
54 |
+
'edge',
|
55 |
+
)
|
56 |
+
|
57 |
+
|
58 |
+
def unpad_patches(image_patches, padding_size):
|
59 |
+
return image_patches[:, padding_size:-padding_size, padding_size:-padding_size, :]
|
60 |
+
|
61 |
+
|
62 |
+
def split_image_into_overlapping_patches(image_array, patch_size, padding_size=2):
|
63 |
+
""" Splits the image into partially overlapping patches.
|
64 |
+
The patches overlap by padding_size pixels.
|
65 |
+
Pads the image twice:
|
66 |
+
- first to have a size multiple of the patch size,
|
67 |
+
- then to have equal padding at the borders.
|
68 |
+
Args:
|
69 |
+
image_array: numpy array of the input image.
|
70 |
+
patch_size: size of the patches from the original image (without padding).
|
71 |
+
padding_size: size of the overlapping area.
|
72 |
+
"""
|
73 |
+
|
74 |
+
xmax, ymax, _ = image_array.shape
|
75 |
+
x_remainder = xmax % patch_size
|
76 |
+
y_remainder = ymax % patch_size
|
77 |
+
|
78 |
+
# modulo here is to avoid extending of patch_size instead of 0
|
79 |
+
x_extend = (patch_size - x_remainder) % patch_size
|
80 |
+
y_extend = (patch_size - y_remainder) % patch_size
|
81 |
+
|
82 |
+
# make sure the image is divisible into regular patches
|
83 |
+
extended_image = np.pad(image_array, ((0, x_extend), (0, y_extend), (0, 0)), 'edge')
|
84 |
+
|
85 |
+
# add padding around the image to simplify computations
|
86 |
+
padded_image = pad_patch(extended_image, padding_size, channel_last=True)
|
87 |
+
|
88 |
+
xmax, ymax, _ = padded_image.shape
|
89 |
+
patches = []
|
90 |
+
|
91 |
+
x_lefts = range(padding_size, xmax - padding_size, patch_size)
|
92 |
+
y_tops = range(padding_size, ymax - padding_size, patch_size)
|
93 |
+
|
94 |
+
for x in x_lefts:
|
95 |
+
for y in y_tops:
|
96 |
+
x_left = x - padding_size
|
97 |
+
y_top = y - padding_size
|
98 |
+
x_right = x + patch_size + padding_size
|
99 |
+
y_bottom = y + patch_size + padding_size
|
100 |
+
patch = padded_image[x_left:x_right, y_top:y_bottom, :]
|
101 |
+
patches.append(patch)
|
102 |
+
|
103 |
+
return np.array(patches), padded_image.shape
|
104 |
+
|
105 |
+
|
106 |
+
def stich_together(patches, padded_image_shape, target_shape, padding_size=4):
|
107 |
+
""" Reconstruct the image from overlapping patches.
|
108 |
+
After scaling, shapes and padding should be scaled too.
|
109 |
+
Args:
|
110 |
+
patches: patches obtained with split_image_into_overlapping_patches
|
111 |
+
padded_image_shape: shape of the padded image contructed in split_image_into_overlapping_patches
|
112 |
+
target_shape: shape of the final image
|
113 |
+
padding_size: size of the overlapping area.
|
114 |
+
"""
|
115 |
+
|
116 |
+
xmax, ymax, _ = padded_image_shape
|
117 |
+
patches = unpad_patches(patches, padding_size)
|
118 |
+
patch_size = patches.shape[1]
|
119 |
+
n_patches_per_row = ymax // patch_size
|
120 |
+
|
121 |
+
complete_image = np.zeros((xmax, ymax, 3))
|
122 |
+
|
123 |
+
row = -1
|
124 |
+
col = 0
|
125 |
+
for i in range(len(patches)):
|
126 |
+
if i % n_patches_per_row == 0:
|
127 |
+
row += 1
|
128 |
+
col = 0
|
129 |
+
complete_image[
|
130 |
+
row * patch_size: (row + 1) * patch_size, col * patch_size: (col + 1) * patch_size,:
|
131 |
+
] = patches[i]
|
132 |
+
col += 1
|
133 |
+
return complete_image[0: target_shape[0], 0: target_shape[1], :]
|
upscaler/__init__.py
ADDED
File without changes
|
upscaler/codeformer.py
ADDED
@@ -0,0 +1,37 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import cv2
|
2 |
+
import torch
|
3 |
+
import onnx
|
4 |
+
import onnxruntime
|
5 |
+
import numpy as np
|
6 |
+
|
7 |
+
import time
|
8 |
+
|
9 |
+
# codeformer converted to onnx
|
10 |
+
# using https://github.com/redthing1/CodeFormer
|
11 |
+
|
12 |
+
|
13 |
+
class CodeFormerEnhancer:
|
14 |
+
def __init__(self, model_path="codeformer.onnx", device='cpu'):
|
15 |
+
model = onnx.load(model_path)
|
16 |
+
session_options = onnxruntime.SessionOptions()
|
17 |
+
session_options.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
|
18 |
+
providers = ["CPUExecutionProvider"]
|
19 |
+
if device == 'cuda':
|
20 |
+
providers = [("CUDAExecutionProvider", {"cudnn_conv_algo_search": "DEFAULT"}),"CPUExecutionProvider"]
|
21 |
+
self.session = onnxruntime.InferenceSession(model_path, sess_options=session_options, providers=providers)
|
22 |
+
|
23 |
+
def enhance(self, img, w=0.9):
|
24 |
+
img = cv2.resize(img, (512, 512), interpolation=cv2.INTER_LINEAR)
|
25 |
+
img = img.astype(np.float32)[:,:,::-1] / 255.0
|
26 |
+
img = img.transpose((2, 0, 1))
|
27 |
+
nrm_mean = np.array([0.5, 0.5, 0.5]).reshape((-1, 1, 1))
|
28 |
+
nrm_std = np.array([0.5, 0.5, 0.5]).reshape((-1, 1, 1))
|
29 |
+
img = (img - nrm_mean) / nrm_std
|
30 |
+
|
31 |
+
img = np.expand_dims(img, axis=0)
|
32 |
+
|
33 |
+
out = self.session.run(None, {'x':img.astype(np.float32), 'w':np.array([w], dtype=np.double)})[0]
|
34 |
+
out = (out[0].transpose(1,2,0).clip(-1,1) + 1) * 0.5
|
35 |
+
out = (out * 255)[:,:,::-1]
|
36 |
+
|
37 |
+
return out.astype('uint8')
|
utils.py
ADDED
@@ -0,0 +1,303 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import cv2
|
3 |
+
import time
|
4 |
+
import glob
|
5 |
+
import shutil
|
6 |
+
import platform
|
7 |
+
import datetime
|
8 |
+
import subprocess
|
9 |
+
import numpy as np
|
10 |
+
from threading import Thread
|
11 |
+
from moviepy.editor import VideoFileClip, ImageSequenceClip
|
12 |
+
from moviepy.video.io.ffmpeg_tools import ffmpeg_extract_subclip
|
13 |
+
|
14 |
+
|
15 |
+
logo_image = cv2.imread("./assets/images/logo.png", cv2.IMREAD_UNCHANGED)
|
16 |
+
|
17 |
+
|
18 |
+
quality_types = ["poor", "low", "medium", "high", "best"]
|
19 |
+
|
20 |
+
|
21 |
+
bitrate_quality_by_resolution = {
|
22 |
+
240: {"poor": "300k", "low": "500k", "medium": "800k", "high": "1000k", "best": "1200k"},
|
23 |
+
360: {"poor": "500k","low": "800k","medium": "1200k","high": "1500k","best": "2000k"},
|
24 |
+
480: {"poor": "800k","low": "1200k","medium": "2000k","high": "2500k","best": "3000k"},
|
25 |
+
720: {"poor": "1500k","low": "2500k","medium": "4000k","high": "5000k","best": "6000k"},
|
26 |
+
1080: {"poor": "2500k","low": "4000k","medium": "6000k","high": "7000k","best": "8000k"},
|
27 |
+
1440: {"poor": "4000k","low": "6000k","medium": "8000k","high": "10000k","best": "12000k"},
|
28 |
+
2160: {"poor": "8000k","low": "10000k","medium": "12000k","high": "15000k","best": "20000k"}
|
29 |
+
}
|
30 |
+
|
31 |
+
|
32 |
+
crf_quality_by_resolution = {
|
33 |
+
240: {"poor": 45, "low": 35, "medium": 28, "high": 23, "best": 20},
|
34 |
+
360: {"poor": 35, "low": 28, "medium": 23, "high": 20, "best": 18},
|
35 |
+
480: {"poor": 28, "low": 23, "medium": 20, "high": 18, "best": 16},
|
36 |
+
720: {"poor": 23, "low": 20, "medium": 18, "high": 16, "best": 14},
|
37 |
+
1080: {"poor": 20, "low": 18, "medium": 16, "high": 14, "best": 12},
|
38 |
+
1440: {"poor": 18, "low": 16, "medium": 14, "high": 12, "best": 10},
|
39 |
+
2160: {"poor": 16, "low": 14, "medium": 12, "high": 10, "best": 8}
|
40 |
+
}
|
41 |
+
|
42 |
+
|
43 |
+
def get_bitrate_for_resolution(resolution, quality):
|
44 |
+
available_resolutions = list(bitrate_quality_by_resolution.keys())
|
45 |
+
closest_resolution = min(available_resolutions, key=lambda x: abs(x - resolution))
|
46 |
+
return bitrate_quality_by_resolution[closest_resolution][quality]
|
47 |
+
|
48 |
+
|
49 |
+
def get_crf_for_resolution(resolution, quality):
|
50 |
+
available_resolutions = list(crf_quality_by_resolution.keys())
|
51 |
+
closest_resolution = min(available_resolutions, key=lambda x: abs(x - resolution))
|
52 |
+
return crf_quality_by_resolution[closest_resolution][quality]
|
53 |
+
|
54 |
+
|
55 |
+
def get_video_bitrate(video_file):
|
56 |
+
ffprobe_cmd = ['ffprobe', '-v', 'error', '-select_streams', 'v:0', '-show_entries',
|
57 |
+
'stream=bit_rate', '-of', 'default=noprint_wrappers=1:nokey=1', video_file]
|
58 |
+
result = subprocess.run(ffprobe_cmd, stdout=subprocess.PIPE)
|
59 |
+
kbps = max(int(result.stdout) // 1000, 10)
|
60 |
+
return str(kbps) + 'k'
|
61 |
+
|
62 |
+
|
63 |
+
def trim_video(video_path, output_path, start_frame, stop_frame):
|
64 |
+
video_name, _ = os.path.splitext(os.path.basename(video_path))
|
65 |
+
trimmed_video_filename = video_name + "_trimmed" + ".mp4"
|
66 |
+
temp_path = os.path.join(output_path, "trim")
|
67 |
+
os.makedirs(temp_path, exist_ok=True)
|
68 |
+
trimmed_video_file_path = os.path.join(temp_path, trimmed_video_filename)
|
69 |
+
|
70 |
+
video = VideoFileClip(video_path, fps_source="fps")
|
71 |
+
fps = video.fps
|
72 |
+
start_time = start_frame / fps
|
73 |
+
duration = (stop_frame - start_frame) / fps
|
74 |
+
|
75 |
+
bitrate = get_bitrate_for_resolution(min(*video.size), "high")
|
76 |
+
|
77 |
+
trimmed_video = video.subclip(start_time, start_time + duration)
|
78 |
+
trimmed_video.write_videofile(
|
79 |
+
trimmed_video_file_path, codec="libx264", audio_codec="aac", bitrate=bitrate,
|
80 |
+
)
|
81 |
+
trimmed_video.close()
|
82 |
+
video.close()
|
83 |
+
|
84 |
+
return trimmed_video_file_path
|
85 |
+
|
86 |
+
|
87 |
+
def open_directory(path=None):
|
88 |
+
if path is None:
|
89 |
+
return
|
90 |
+
try:
|
91 |
+
os.startfile(path)
|
92 |
+
except:
|
93 |
+
subprocess.Popen(["xdg-open", path])
|
94 |
+
|
95 |
+
|
96 |
+
class StreamerThread(object):
|
97 |
+
def __init__(self, src=0):
|
98 |
+
self.capture = cv2.VideoCapture(src)
|
99 |
+
self.capture.set(cv2.CAP_PROP_BUFFERSIZE, 2)
|
100 |
+
self.FPS = 1 / 30
|
101 |
+
self.FPS_MS = int(self.FPS * 1000)
|
102 |
+
self.thread = None
|
103 |
+
self.stopped = False
|
104 |
+
self.frame = None
|
105 |
+
|
106 |
+
def start(self):
|
107 |
+
self.thread = Thread(target=self.update, args=())
|
108 |
+
self.thread.daemon = True
|
109 |
+
self.thread.start()
|
110 |
+
|
111 |
+
def stop(self):
|
112 |
+
self.stopped = True
|
113 |
+
self.thread.join()
|
114 |
+
print("stopped")
|
115 |
+
|
116 |
+
def update(self):
|
117 |
+
while not self.stopped:
|
118 |
+
if self.capture.isOpened():
|
119 |
+
(self.status, self.frame) = self.capture.read()
|
120 |
+
time.sleep(self.FPS)
|
121 |
+
|
122 |
+
|
123 |
+
class ProcessBar:
|
124 |
+
def __init__(self, bar_length, total, before="⬛", after="🟨"):
|
125 |
+
self.bar_length = bar_length
|
126 |
+
self.total = total
|
127 |
+
self.before = before
|
128 |
+
self.after = after
|
129 |
+
self.bar = [self.before] * bar_length
|
130 |
+
self.start_time = time.time()
|
131 |
+
|
132 |
+
def get(self, index):
|
133 |
+
total = self.total
|
134 |
+
elapsed_time = time.time() - self.start_time
|
135 |
+
average_time_per_iteration = elapsed_time / (index + 1)
|
136 |
+
remaining_iterations = total - (index + 1)
|
137 |
+
estimated_remaining_time = remaining_iterations * average_time_per_iteration
|
138 |
+
|
139 |
+
self.bar[int(index / total * self.bar_length)] = self.after
|
140 |
+
info_text = f"({index+1}/{total}) {''.join(self.bar)} "
|
141 |
+
info_text += f"(ETR: {int(estimated_remaining_time // 60)} min {int(estimated_remaining_time % 60)} sec)"
|
142 |
+
return info_text
|
143 |
+
|
144 |
+
|
145 |
+
def add_logo_to_image(img, logo=logo_image):
|
146 |
+
logo_size = int(img.shape[1] * 0.1)
|
147 |
+
logo = cv2.resize(logo, (logo_size, logo_size))
|
148 |
+
if logo.shape[2] == 4:
|
149 |
+
alpha = logo[:, :, 3]
|
150 |
+
else:
|
151 |
+
alpha = np.ones_like(logo[:, :, 0]) * 255
|
152 |
+
padding = int(logo_size * 0.1)
|
153 |
+
roi = img.shape[0] - logo_size - padding, img.shape[1] - logo_size - padding
|
154 |
+
for c in range(0, 3):
|
155 |
+
img[roi[0] : roi[0] + logo_size, roi[1] : roi[1] + logo_size, c] = (
|
156 |
+
alpha / 255.0
|
157 |
+
) * logo[:, :, c] + (1 - alpha / 255.0) * img[
|
158 |
+
roi[0] : roi[0] + logo_size, roi[1] : roi[1] + logo_size, c
|
159 |
+
]
|
160 |
+
return img
|
161 |
+
|
162 |
+
|
163 |
+
def split_list_by_lengths(data, length_list):
|
164 |
+
split_data = []
|
165 |
+
start_idx = 0
|
166 |
+
for length in length_list:
|
167 |
+
end_idx = start_idx + length
|
168 |
+
sublist = data[start_idx:end_idx]
|
169 |
+
split_data.append(sublist)
|
170 |
+
start_idx = end_idx
|
171 |
+
return split_data
|
172 |
+
|
173 |
+
|
174 |
+
def merge_img_sequence_from_ref(ref_video_path, image_sequence, output_file_name):
|
175 |
+
video_clip = VideoFileClip(ref_video_path, fps_source="fps")
|
176 |
+
fps = video_clip.fps
|
177 |
+
duration = video_clip.duration
|
178 |
+
total_frames = video_clip.reader.nframes
|
179 |
+
audio_clip = video_clip.audio if video_clip.audio is not None else None
|
180 |
+
edited_video_clip = ImageSequenceClip(image_sequence, fps=fps)
|
181 |
+
|
182 |
+
if audio_clip is not None:
|
183 |
+
edited_video_clip = edited_video_clip.set_audio(audio_clip)
|
184 |
+
|
185 |
+
bitrate = get_bitrate_for_resolution(min(*edited_video_clip.size), "high")
|
186 |
+
|
187 |
+
edited_video_clip.set_duration(duration).write_videofile(
|
188 |
+
output_file_name, codec="libx264", bitrate=bitrate,
|
189 |
+
)
|
190 |
+
edited_video_clip.close()
|
191 |
+
video_clip.close()
|
192 |
+
|
193 |
+
|
194 |
+
def scale_bbox_from_center(bbox, scale_width, scale_height, image_width, image_height):
|
195 |
+
# Extract the coordinates of the bbox
|
196 |
+
x1, y1, x2, y2 = bbox
|
197 |
+
|
198 |
+
# Calculate the center point of the bbox
|
199 |
+
center_x = (x1 + x2) / 2
|
200 |
+
center_y = (y1 + y2) / 2
|
201 |
+
|
202 |
+
# Calculate the new width and height of the bbox based on the scaling factors
|
203 |
+
width = x2 - x1
|
204 |
+
height = y2 - y1
|
205 |
+
new_width = width * scale_width
|
206 |
+
new_height = height * scale_height
|
207 |
+
|
208 |
+
# Calculate the new coordinates of the bbox, considering the image boundaries
|
209 |
+
new_x1 = center_x - new_width / 2
|
210 |
+
new_y1 = center_y - new_height / 2
|
211 |
+
new_x2 = center_x + new_width / 2
|
212 |
+
new_y2 = center_y + new_height / 2
|
213 |
+
|
214 |
+
# Adjust the coordinates to ensure the bbox remains within the image boundaries
|
215 |
+
new_x1 = max(0, new_x1)
|
216 |
+
new_y1 = max(0, new_y1)
|
217 |
+
new_x2 = min(image_width - 1, new_x2)
|
218 |
+
new_y2 = min(image_height - 1, new_y2)
|
219 |
+
|
220 |
+
# Return the scaled bbox coordinates
|
221 |
+
scaled_bbox = [new_x1, new_y1, new_x2, new_y2]
|
222 |
+
return scaled_bbox
|
223 |
+
|
224 |
+
|
225 |
+
def laplacian_blending(A, B, m, num_levels=7):
|
226 |
+
assert A.shape == B.shape
|
227 |
+
assert B.shape == m.shape
|
228 |
+
height = m.shape[0]
|
229 |
+
width = m.shape[1]
|
230 |
+
size_list = np.array([4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096, 8192])
|
231 |
+
size = size_list[np.where(size_list > max(height, width))][0]
|
232 |
+
GA = np.zeros((size, size, 3), dtype=np.float32)
|
233 |
+
GA[:height, :width, :] = A
|
234 |
+
GB = np.zeros((size, size, 3), dtype=np.float32)
|
235 |
+
GB[:height, :width, :] = B
|
236 |
+
GM = np.zeros((size, size, 3), dtype=np.float32)
|
237 |
+
GM[:height, :width, :] = m
|
238 |
+
gpA = [GA]
|
239 |
+
gpB = [GB]
|
240 |
+
gpM = [GM]
|
241 |
+
for i in range(num_levels):
|
242 |
+
GA = cv2.pyrDown(GA)
|
243 |
+
GB = cv2.pyrDown(GB)
|
244 |
+
GM = cv2.pyrDown(GM)
|
245 |
+
gpA.append(np.float32(GA))
|
246 |
+
gpB.append(np.float32(GB))
|
247 |
+
gpM.append(np.float32(GM))
|
248 |
+
lpA = [gpA[num_levels-1]]
|
249 |
+
lpB = [gpB[num_levels-1]]
|
250 |
+
gpMr = [gpM[num_levels-1]]
|
251 |
+
for i in range(num_levels-1,0,-1):
|
252 |
+
LA = np.subtract(gpA[i-1], cv2.pyrUp(gpA[i]))
|
253 |
+
LB = np.subtract(gpB[i-1], cv2.pyrUp(gpB[i]))
|
254 |
+
lpA.append(LA)
|
255 |
+
lpB.append(LB)
|
256 |
+
gpMr.append(gpM[i-1])
|
257 |
+
LS = []
|
258 |
+
for la,lb,gm in zip(lpA,lpB,gpMr):
|
259 |
+
ls = la * gm + lb * (1.0 - gm)
|
260 |
+
LS.append(ls)
|
261 |
+
ls_ = LS[0]
|
262 |
+
for i in range(1,num_levels):
|
263 |
+
ls_ = cv2.pyrUp(ls_)
|
264 |
+
ls_ = cv2.add(ls_, LS[i])
|
265 |
+
ls_ = ls_[:height, :width, :]
|
266 |
+
#ls_ = (ls_ - np.min(ls_)) * (255.0 / (np.max(ls_) - np.min(ls_)))
|
267 |
+
return ls_.clip(0, 255)
|
268 |
+
|
269 |
+
|
270 |
+
def mask_crop(mask, crop):
|
271 |
+
top, bottom, left, right = crop
|
272 |
+
shape = mask.shape
|
273 |
+
top = int(top)
|
274 |
+
bottom = int(bottom)
|
275 |
+
if top + bottom < shape[1]:
|
276 |
+
if top > 0: mask[:top, :] = 0
|
277 |
+
if bottom > 0: mask[-bottom:, :] = 0
|
278 |
+
|
279 |
+
left = int(left)
|
280 |
+
right = int(right)
|
281 |
+
if left + right < shape[0]:
|
282 |
+
if left > 0: mask[:, :left] = 0
|
283 |
+
if right > 0: mask[:, -right:] = 0
|
284 |
+
|
285 |
+
return mask
|
286 |
+
|
287 |
+
def create_image_grid(images, size=128):
|
288 |
+
num_images = len(images)
|
289 |
+
num_cols = int(np.ceil(np.sqrt(num_images)))
|
290 |
+
num_rows = int(np.ceil(num_images / num_cols))
|
291 |
+
grid = np.zeros((num_rows * size, num_cols * size, 3), dtype=np.uint8)
|
292 |
+
|
293 |
+
for i, image in enumerate(images):
|
294 |
+
row_idx = (i // num_cols) * size
|
295 |
+
col_idx = (i % num_cols) * size
|
296 |
+
image = cv2.resize(image.copy(), (size,size))
|
297 |
+
if image.dtype != np.uint8:
|
298 |
+
image = (image.astype('float32') * 255).astype('uint8')
|
299 |
+
if image.ndim == 2:
|
300 |
+
image = cv2.cvtColor(image, cv2.COLOR_GRAY2BGR)
|
301 |
+
grid[row_idx:row_idx + size, col_idx:col_idx + size] = image
|
302 |
+
|
303 |
+
return grid
|