Suniilkumaar bluefoxcreation committed
Commit 0fc4c70 · 0 Parent(s)

Duplicate from bluefoxcreation/SwapMukham


Co-authored-by: BlueFox <[email protected]>

.gitattributes ADDED
@@ -0,0 +1,35 @@
+ *.7z filter=lfs diff=lfs merge=lfs -text
+ *.arrow filter=lfs diff=lfs merge=lfs -text
+ *.bin filter=lfs diff=lfs merge=lfs -text
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
+ *.ftz filter=lfs diff=lfs merge=lfs -text
+ *.gz filter=lfs diff=lfs merge=lfs -text
+ *.h5 filter=lfs diff=lfs merge=lfs -text
+ *.joblib filter=lfs diff=lfs merge=lfs -text
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
+ *.model filter=lfs diff=lfs merge=lfs -text
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
+ *.npy filter=lfs diff=lfs merge=lfs -text
+ *.npz filter=lfs diff=lfs merge=lfs -text
+ *.onnx filter=lfs diff=lfs merge=lfs -text
+ *.ot filter=lfs diff=lfs merge=lfs -text
+ *.parquet filter=lfs diff=lfs merge=lfs -text
+ *.pb filter=lfs diff=lfs merge=lfs -text
+ *.pickle filter=lfs diff=lfs merge=lfs -text
+ *.pkl filter=lfs diff=lfs merge=lfs -text
+ *.pt filter=lfs diff=lfs merge=lfs -text
+ *.pth filter=lfs diff=lfs merge=lfs -text
+ *.rar filter=lfs diff=lfs merge=lfs -text
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
+ *.tar filter=lfs diff=lfs merge=lfs -text
+ *.tflite filter=lfs diff=lfs merge=lfs -text
+ *.tgz filter=lfs diff=lfs merge=lfs -text
+ *.wasm filter=lfs diff=lfs merge=lfs -text
+ *.xz filter=lfs diff=lfs merge=lfs -text
+ *.zip filter=lfs diff=lfs merge=lfs -text
+ *.zst filter=lfs diff=lfs merge=lfs -text
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,2 @@
+
+ *.pyc
README.md ADDED
@@ -0,0 +1,14 @@
+ ---
+ title: Swap Mukham
+ emoji: 💻
+ colorFrom: purple
+ colorTo: gray
+ sdk: gradio
+ sdk_version: 3.35.2
+ app_file: app.py
+ pinned: false
+ license: unknown
+ duplicated_from: bluefoxcreation/SwapMukham
+ ---
+
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py ADDED
@@ -0,0 +1,924 @@
+ import os
+ import cv2
+ import glob
+ import time
+ import torch
+ import shutil
+ import argparse
+ import platform
+ import datetime
+ import subprocess
+ import insightface
+ import onnxruntime
+ import numpy as np
+ import gradio as gr
+ import threading
+ import queue
+ from tqdm import tqdm
+ import concurrent.futures
+ from moviepy.editor import VideoFileClip
+
+ from nsfw_checker import NSFWChecker
+ from face_swapper import Inswapper, paste_to_whole
+ from face_analyser import detect_conditions, get_analysed_data, swap_options_list
+ from face_parsing import init_parsing_model, get_parsed_mask, mask_regions, mask_regions_to_list
+ from face_enhancer import get_available_enhancer_names, load_face_enhancer_model, cv2_interpolations
+ from utils import trim_video, StreamerThread, ProcessBar, open_directory, split_list_by_lengths, merge_img_sequence_from_ref, create_image_grid
+
+ ## ------------------------------ USER ARGS ------------------------------
+
+ parser = argparse.ArgumentParser(description="Swap-Mukham Face Swapper")
+ parser.add_argument("--out_dir", help="Default Output directory", default=os.getcwd())
+ parser.add_argument("--batch_size", help="Gpu batch size", default=32)
+ parser.add_argument("--cuda", action="store_true", help="Enable cuda", default=False)
+ parser.add_argument(
+     "--colab", action="store_true", help="Enable colab mode", default=False
+ )
+ user_args = parser.parse_args()
+
+ ## ------------------------------ DEFAULTS ------------------------------
+
+ USE_COLAB = user_args.colab
+ USE_CUDA = user_args.cuda
+ DEF_OUTPUT_PATH = user_args.out_dir
+ BATCH_SIZE = int(user_args.batch_size)
+ WORKSPACE = None
+ OUTPUT_FILE = None
+ CURRENT_FRAME = None
+ STREAMER = None
+ DETECT_CONDITION = "best detection"
+ DETECT_SIZE = 640
+ DETECT_THRESH = 0.6
+ NUM_OF_SRC_SPECIFIC = 10
+ MASK_INCLUDE = [
+     "Skin",
+     "R-Eyebrow",
+     "L-Eyebrow",
+     "L-Eye",
+     "R-Eye",
+     "Nose",
+     "Mouth",
+     "L-Lip",
+     "U-Lip"
+ ]
+ MASK_SOFT_KERNEL = 17
+ MASK_SOFT_ITERATIONS = 10
+ MASK_BLUR_AMOUNT = 0.1
+ MASK_ERODE_AMOUNT = 0.15
+
+ FACE_SWAPPER = None
+ FACE_ANALYSER = None
+ FACE_ENHANCER = None
+ FACE_PARSER = None
+ NSFW_DETECTOR = None
+ FACE_ENHANCER_LIST = ["NONE"]
+ FACE_ENHANCER_LIST.extend(get_available_enhancer_names())
+ FACE_ENHANCER_LIST.extend(cv2_interpolations)
+
+ ## ------------------------------ SET EXECUTION PROVIDER ------------------------------
+ # Note: Non CUDA users may change settings here
+
+ PROVIDER = ["CPUExecutionProvider"]
+
+ if USE_CUDA:
+     available_providers = onnxruntime.get_available_providers()
+     if "CUDAExecutionProvider" in available_providers:
+         print("\n********** Running on CUDA **********\n")
+         PROVIDER = ["CUDAExecutionProvider", "CPUExecutionProvider"]
+     else:
+         USE_CUDA = False
+         print("\n********** CUDA unavailable running on CPU **********\n")
+ else:
+     USE_CUDA = False
+     print("\n********** Running on CPU **********\n")
+
+ device = "cuda" if USE_CUDA else "cpu"
+ EMPTY_CACHE = lambda: torch.cuda.empty_cache() if device == "cuda" else None
+
+ ## ------------------------------ LOAD MODELS ------------------------------
+
+ def load_face_analyser_model(name="buffalo_l"):
+     global FACE_ANALYSER
+     if FACE_ANALYSER is None:
+         FACE_ANALYSER = insightface.app.FaceAnalysis(name=name, providers=PROVIDER)
+         FACE_ANALYSER.prepare(
+             ctx_id=0, det_size=(DETECT_SIZE, DETECT_SIZE), det_thresh=DETECT_THRESH
+         )
+
+
+ def load_face_swapper_model(path="./assets/pretrained_models/inswapper_128.onnx"):
+     global FACE_SWAPPER
+     if FACE_SWAPPER is None:
+         batch = int(BATCH_SIZE) if device == "cuda" else 1
+         FACE_SWAPPER = Inswapper(model_file=path, batch_size=batch, providers=PROVIDER)
+
+
+ def load_face_parser_model(path="./assets/pretrained_models/79999_iter.pth"):
+     global FACE_PARSER
+     if FACE_PARSER is None:
+         FACE_PARSER = init_parsing_model(path, device=device)
+
+ def load_nsfw_detector_model(path="./assets/pretrained_models/open-nsfw.onnx"):
+     global NSFW_DETECTOR
+     if NSFW_DETECTOR is None:
+         NSFW_DETECTOR = NSFWChecker(model_path=path, providers=PROVIDER)
+
+
+ load_face_analyser_model()
+ load_face_swapper_model()
+
+ ## ------------------------------ MAIN PROCESS ------------------------------
+
+
+ def process(
+     input_type,
+     image_path,
+     video_path,
+     directory_path,
+     source_path,
+     output_path,
+     output_name,
+     keep_output_sequence,
+     condition,
+     age,
+     distance,
+     face_enhancer_name,
+     enable_face_parser,
+     mask_includes,
+     mask_soft_kernel,
+     mask_soft_iterations,
+     blur_amount,
+     erode_amount,
+     face_scale,
+     enable_laplacian_blend,
+     crop_top,
+     crop_bott,
+     crop_left,
+     crop_right,
+     *specifics,
+ ):
+     global WORKSPACE
+     global OUTPUT_FILE
+     global PREVIEW
+     WORKSPACE, OUTPUT_FILE, PREVIEW = None, None, None
+
+     ## ------------------------------ GUI UPDATE FUNC ------------------------------
+
+     def ui_before():
+         return (
+             gr.update(visible=True, value=PREVIEW),
+             gr.update(interactive=False),
+             gr.update(interactive=False),
+             gr.update(visible=False),
+         )
+
+     def ui_after():
+         return (
+             gr.update(visible=True, value=PREVIEW),
+             gr.update(interactive=True),
+             gr.update(interactive=True),
+             gr.update(visible=False),
+         )
+
+     def ui_after_vid():
+         return (
+             gr.update(visible=False),
+             gr.update(interactive=True),
+             gr.update(interactive=True),
+             gr.update(value=OUTPUT_FILE, visible=True),
+         )
+
+     start_time = time.time()
+     total_exec_time = lambda start_time: divmod(time.time() - start_time, 60)
+     get_finsh_text = lambda start_time: f"✔️ Completed in {int(total_exec_time(start_time)[0])} min {int(total_exec_time(start_time)[1])} sec."
+
+     ## ------------------------------ PREPARE INPUTS & LOAD MODELS ------------------------------
+
+     yield "### \n ⌛ Loading NSFW detector model...", *ui_before()
+     load_nsfw_detector_model()
+
+     yield "### \n ⌛ Loading face analyser model...", *ui_before()
+     load_face_analyser_model()
+
+     yield "### \n ⌛ Loading face swapper model...", *ui_before()
+     load_face_swapper_model()
+
+     if face_enhancer_name != "NONE":
+         if face_enhancer_name not in cv2_interpolations:
+             yield f"### \n ⌛ Loading {face_enhancer_name} model...", *ui_before()
+         FACE_ENHANCER = load_face_enhancer_model(name=face_enhancer_name, device=device)
+     else:
+         FACE_ENHANCER = None
+
+     if enable_face_parser:
+         yield "### \n ⌛ Loading face parsing model...", *ui_before()
+         load_face_parser_model()
+
+     includes = mask_regions_to_list(mask_includes)
+     specifics = list(specifics)
+     half = len(specifics) // 2
+     sources = specifics[:half]
+     specifics = specifics[half:]
+     if crop_top > crop_bott:
+         crop_top, crop_bott = crop_bott, crop_top
+     if crop_left > crop_right:
+         crop_left, crop_right = crop_right, crop_left
+     crop_mask = (crop_top, 511-crop_bott, crop_left, 511-crop_right)
+
+     def swap_process(image_sequence):
+         ## ------------------------------ CONTENT CHECK ------------------------------
+
+         yield "### \n ⌛ Checking contents...", *ui_before()
+         nsfw = NSFW_DETECTOR.is_nsfw(image_sequence)
+         if nsfw:
+             message = "NSFW Content detected !!!"
+             yield f"### \n 🔞 {message}", *ui_before()
+             assert not nsfw, message
+             return False
+         EMPTY_CACHE()
+
+         ## ------------------------------ ANALYSE FACE ------------------------------
+
+         yield "### \n ⌛ Analysing face data...", *ui_before()
+         if condition != "Specific Face":
+             source_data = source_path, age
+         else:
+             source_data = ((sources, specifics), distance)
+         analysed_targets, analysed_sources, whole_frame_list, num_faces_per_frame = get_analysed_data(
+             FACE_ANALYSER,
+             image_sequence,
+             source_data,
+             swap_condition=condition,
+             detect_condition=DETECT_CONDITION,
+             scale=face_scale
+         )
+
+         ## ------------------------------ SWAP FUNC ------------------------------
+
+         yield "### \n ⌛ Generating faces...", *ui_before()
+         preds = []
+         matrs = []
+         count = 0
+         global PREVIEW
+         for batch_pred, batch_matr in FACE_SWAPPER.batch_forward(whole_frame_list, analysed_targets, analysed_sources):
+             preds.extend(batch_pred)
+             matrs.extend(batch_matr)
+             EMPTY_CACHE()
+             count += 1
+
+             if USE_CUDA:
+                 image_grid = create_image_grid(batch_pred, size=128)
+                 PREVIEW = image_grid[:, :, ::-1]
+                 yield f"### \n ⌛ Generating face Batch {count}", *ui_before()
+
+         ## ------------------------------ FACE ENHANCEMENT ------------------------------
+
+         generated_len = len(preds)
+         if face_enhancer_name != "NONE":
+             yield f"### \n ⌛ Upscaling faces with {face_enhancer_name}...", *ui_before()
+             for idx, pred in tqdm(enumerate(preds), total=generated_len, desc=f"Upscaling with {face_enhancer_name}"):
+                 enhancer_model, enhancer_model_runner = FACE_ENHANCER
+                 pred = enhancer_model_runner(pred, enhancer_model)
+                 preds[idx] = cv2.resize(pred, (512,512))
+         EMPTY_CACHE()
+
+         ## ------------------------------ FACE PARSING ------------------------------
+
+         if enable_face_parser:
+             yield "### \n ⌛ Face-parsing mask...", *ui_before()
+             masks = []
+             count = 0
+             for batch_mask in get_parsed_mask(FACE_PARSER, preds, classes=includes, device=device, batch_size=BATCH_SIZE, softness=int(mask_soft_iterations)):
+                 masks.append(batch_mask)
+                 EMPTY_CACHE()
+                 count += 1
+
+                 if len(batch_mask) > 1:
+                     image_grid = create_image_grid(batch_mask, size=128)
+                     PREVIEW = image_grid[:, :, ::-1]
+                     yield f"### \n ⌛ Face parsing Batch {count}", *ui_before()
+             masks = np.concatenate(masks, axis=0) if len(masks) >= 1 else masks
+         else:
+             masks = [None] * generated_len
+
+         ## ------------------------------ SPLIT LIST ------------------------------
+
+         split_preds = split_list_by_lengths(preds, num_faces_per_frame)
+         del preds
+         split_matrs = split_list_by_lengths(matrs, num_faces_per_frame)
+         del matrs
+         split_masks = split_list_by_lengths(masks, num_faces_per_frame)
+         del masks
+
+         ## ------------------------------ PASTE-BACK ------------------------------
+
+         yield "### \n ⌛ Pasting back...", *ui_before()
+         def post_process(frame_idx, frame_img, split_preds, split_matrs, split_masks, enable_laplacian_blend, crop_mask, blur_amount, erode_amount):
+             whole_img_path = frame_img
+             whole_img = cv2.imread(whole_img_path)
+             blend_method = 'laplacian' if enable_laplacian_blend else 'linear'
+             for p, m, mask in zip(split_preds[frame_idx], split_matrs[frame_idx], split_masks[frame_idx]):
+                 p = cv2.resize(p, (512,512))
+                 mask = cv2.resize(mask, (512,512)) if mask is not None else None
+                 m /= 0.25
+                 whole_img = paste_to_whole(p, whole_img, m, mask=mask, crop_mask=crop_mask, blend_method=blend_method, blur_amount=blur_amount, erode_amount=erode_amount)
+             cv2.imwrite(whole_img_path, whole_img)
+
+         def concurrent_post_process(image_sequence, *args):
+             with concurrent.futures.ThreadPoolExecutor() as executor:
+                 futures = []
+                 for idx, frame_img in enumerate(image_sequence):
+                     future = executor.submit(post_process, idx, frame_img, *args)
+                     futures.append(future)
+
+                 for future in tqdm(concurrent.futures.as_completed(futures), total=len(futures), desc="Pasting back"):
+                     result = future.result()
+
+         concurrent_post_process(
+             image_sequence,
+             split_preds,
+             split_matrs,
+             split_masks,
+             enable_laplacian_blend,
+             crop_mask,
+             blur_amount,
+             erode_amount
+         )
+
+
+     ## ------------------------------ IMAGE ------------------------------
+
+     if input_type == "Image":
+         target = cv2.imread(image_path)
+         output_file = os.path.join(output_path, output_name + ".png")
+         cv2.imwrite(output_file, target)
+
+         for info_update in swap_process([output_file]):
+             yield info_update
+
+         OUTPUT_FILE = output_file
+         WORKSPACE = output_path
+         PREVIEW = cv2.imread(output_file)[:, :, ::-1]
+
+         yield get_finsh_text(start_time), *ui_after()
+
+     ## ------------------------------ VIDEO ------------------------------
+
+     elif input_type == "Video":
+         temp_path = os.path.join(output_path, output_name, "sequence")
+         os.makedirs(temp_path, exist_ok=True)
+
+         yield "### \n ⌛ Extracting video frames...", *ui_before()
+         image_sequence = []
+         cap = cv2.VideoCapture(video_path)
+         curr_idx = 0
+         while True:
+             ret, frame = cap.read()
+             if not ret: break
+             frame_path = os.path.join(temp_path, f"frame_{curr_idx}.jpg")
+             cv2.imwrite(frame_path, frame)
+             image_sequence.append(frame_path)
+             curr_idx += 1
+         cap.release()
+         cv2.destroyAllWindows()
+
+         for info_update in swap_process(image_sequence):
+             yield info_update
+
+         yield "### \n ⌛ Merging sequence...", *ui_before()
+         output_video_path = os.path.join(output_path, output_name + ".mp4")
+         merge_img_sequence_from_ref(video_path, image_sequence, output_video_path)
+
+         if os.path.exists(temp_path) and not keep_output_sequence:
+             yield "### \n ⌛ Removing temporary files...", *ui_before()
+             shutil.rmtree(temp_path)
+
+         WORKSPACE = output_path
+         OUTPUT_FILE = output_video_path
+
+         yield get_finsh_text(start_time), *ui_after_vid()
+
+     ## ------------------------------ DIRECTORY ------------------------------
+
+     elif input_type == "Directory":
+         extensions = ["jpg", "jpeg", "png", "bmp", "tiff", "ico", "webp"]
+         temp_path = os.path.join(output_path, output_name)
+         if os.path.exists(temp_path):
+             shutil.rmtree(temp_path)
+         os.mkdir(temp_path)
+
+         file_paths = []
+         for file_path in glob.glob(os.path.join(directory_path, "*")):
+             if any(file_path.lower().endswith(ext) for ext in extensions):
+                 img = cv2.imread(file_path)
+                 new_file_path = os.path.join(temp_path, os.path.basename(file_path))
+                 cv2.imwrite(new_file_path, img)
+                 file_paths.append(new_file_path)
+
+         for info_update in swap_process(file_paths):
+             yield info_update
+
+         PREVIEW = cv2.imread(file_paths[-1])[:, :, ::-1]
+         WORKSPACE = temp_path
+         OUTPUT_FILE = file_paths[-1]
+
+         yield get_finsh_text(start_time), *ui_after()
+
+     ## ------------------------------ STREAM ------------------------------
+
+     elif input_type == "Stream":
+         pass
+
+
+ ## ------------------------------ GRADIO FUNC ------------------------------
+
+
+ def update_radio(value):
+     if value == "Image":
+         return (
+             gr.update(visible=True),
+             gr.update(visible=False),
+             gr.update(visible=False),
+         )
+     elif value == "Video":
+         return (
+             gr.update(visible=False),
+             gr.update(visible=True),
+             gr.update(visible=False),
+         )
+     elif value == "Directory":
+         return (
+             gr.update(visible=False),
+             gr.update(visible=False),
+             gr.update(visible=True),
+         )
+     elif value == "Stream":
+         return (
+             gr.update(visible=False),
+             gr.update(visible=False),
+             gr.update(visible=True),
+         )
+
+
+ def swap_option_changed(value):
+     if value.startswith("Age"):
+         return (
+             gr.update(visible=True),
+             gr.update(visible=False),
+             gr.update(visible=True),
+         )
+     elif value == "Specific Face":
+         return (
+             gr.update(visible=False),
+             gr.update(visible=True),
+             gr.update(visible=False),
+         )
+     return gr.update(visible=False), gr.update(visible=False), gr.update(visible=True)
+
+
+ def video_changed(video_path):
+     sliders_update = gr.Slider.update
+     button_update = gr.Button.update
+     number_update = gr.Number.update
+
+     if video_path is None:
+         return (
+             sliders_update(minimum=0, maximum=0, value=0),
+             sliders_update(minimum=1, maximum=1, value=1),
+             number_update(value=1),
+         )
+     try:
+         clip = VideoFileClip(video_path)
+         fps = clip.fps
+         total_frames = clip.reader.nframes
+         clip.close()
+         return (
+             sliders_update(minimum=0, maximum=total_frames, value=0, interactive=True),
+             sliders_update(
+                 minimum=0, maximum=total_frames, value=total_frames, interactive=True
+             ),
+             number_update(value=fps),
+         )
+     except:
+         return (
+             sliders_update(value=0),
+             sliders_update(value=0),
+             number_update(value=1),
+         )
+
+
+ def analyse_settings_changed(detect_condition, detection_size, detection_threshold):
+     yield "### \n ⌛ Applying new values..."
+     global FACE_ANALYSER
+     global DETECT_CONDITION
+     DETECT_CONDITION = detect_condition
+     FACE_ANALYSER = insightface.app.FaceAnalysis(name="buffalo_l", providers=PROVIDER)
+     FACE_ANALYSER.prepare(
+         ctx_id=0,
+         det_size=(int(detection_size), int(detection_size)),
+         det_thresh=float(detection_threshold),
+     )
+     yield f"### \n ✔️ Applied detect condition:{detect_condition}, detection size: {detection_size}, detection threshold: {detection_threshold}"
+
+
+ def stop_running():
+     global STREAMER
+     if hasattr(STREAMER, "stop"):
+         STREAMER.stop()
+         STREAMER = None
+     return "Cancelled"
+
+
+ def slider_changed(show_frame, video_path, frame_index):
+     if not show_frame:
+         return None, None
+     if video_path is None:
+         return None, None
+     clip = VideoFileClip(video_path)
+     frame = clip.get_frame(frame_index / clip.fps)
+     frame_array = np.array(frame)
+     clip.close()
+     return gr.Image.update(value=frame_array, visible=True), gr.Video.update(
+         visible=False
+     )
+
+
+ def trim_and_reload(video_path, output_path, output_name, start_frame, stop_frame):
+     yield video_path, f"### \n ⌛ Trimming video frame {start_frame} to {stop_frame}..."
+     try:
+         output_path = os.path.join(output_path, output_name)
+         trimmed_video = trim_video(video_path, output_path, start_frame, stop_frame)
+         yield trimmed_video, "### \n ✔️ Video trimmed and reloaded."
+     except Exception as e:
+         print(e)
+         yield video_path, "### \n ❌ Video trimming failed. See console for more info."
+
+
+ ## ------------------------------ GRADIO GUI ------------------------------
+
+ css = """
+ footer{display:none !important}
+ """
+
+ with gr.Blocks(css=css) as interface:
+     gr.Markdown("# 🗿 Swap Mukham")
+     gr.Markdown("### Face swap app based on insightface inswapper.")
+     with gr.Row():
+         with gr.Row():
+             with gr.Column(scale=0.4):
+                 with gr.Tab("📄 Swap Condition"):
+                     swap_option = gr.Dropdown(
+                         swap_options_list,
+                         info="Choose which face or faces in the target image to swap.",
+                         multiselect=False,
+                         show_label=False,
+                         value=swap_options_list[0],
+                         interactive=True,
+                     )
+                     age = gr.Number(
+                         value=25, label="Value", interactive=True, visible=False
+                     )
+
+                 with gr.Tab("🎚️ Detection Settings"):
+                     detect_condition_dropdown = gr.Dropdown(
+                         detect_conditions,
+                         label="Condition",
+                         value=DETECT_CONDITION,
+                         interactive=True,
+                         info="This condition is only used when multiple faces are detected on source or specific image.",
+                     )
+                     detection_size = gr.Number(
+                         label="Detection Size", value=DETECT_SIZE, interactive=True
+                     )
+                     detection_threshold = gr.Number(
+                         label="Detection Threshold",
+                         value=DETECT_THRESH,
+                         interactive=True,
+                     )
+                     apply_detection_settings = gr.Button("Apply settings")
+
+                 with gr.Tab("📤 Output Settings"):
+                     output_directory = gr.Text(
+                         label="Output Directory",
+                         value=DEF_OUTPUT_PATH,
+                         interactive=True,
+                     )
+                     output_name = gr.Text(
+                         label="Output Name", value="Result", interactive=True
+                     )
+                     keep_output_sequence = gr.Checkbox(
+                         label="Keep output sequence", value=False, interactive=True
+                     )
+
+                 with gr.Tab("🪄 Other Settings"):
+                     face_scale = gr.Slider(
+                         label="Face Scale",
+                         minimum=0,
+                         maximum=2,
+                         value=1,
+                         interactive=True,
+                     )
+
+                     face_enhancer_name = gr.Dropdown(
+                         FACE_ENHANCER_LIST, label="Face Enhancer", value="NONE", multiselect=False, interactive=True
+                     )
+
+                     with gr.Accordion("Advanced Mask", open=False):
+                         enable_face_parser_mask = gr.Checkbox(
+                             label="Enable Face Parsing",
+                             value=False,
+                             interactive=True,
+                         )
+
+                         mask_include = gr.Dropdown(
+                             mask_regions.keys(),
+                             value=MASK_INCLUDE,
+                             multiselect=True,
+                             label="Include",
+                             interactive=True,
+                         )
+                         mask_soft_kernel = gr.Number(
+                             label="Soft Erode Kernel",
+                             value=MASK_SOFT_KERNEL,
+                             minimum=3,
+                             interactive=True,
+                             visible=False
+                         )
+                         mask_soft_iterations = gr.Number(
+                             label="Soft Erode Iterations",
+                             value=MASK_SOFT_ITERATIONS,
+                             minimum=0,
+                             interactive=True,
+
+                         )
+
+
+                     with gr.Accordion("Crop Mask", open=False):
+                         crop_top = gr.Slider(label="Top", minimum=0, maximum=511, value=0, step=1, interactive=True)
+                         crop_bott = gr.Slider(label="Bottom", minimum=0, maximum=511, value=511, step=1, interactive=True)
+                         crop_left = gr.Slider(label="Left", minimum=0, maximum=511, value=0, step=1, interactive=True)
+                         crop_right = gr.Slider(label="Right", minimum=0, maximum=511, value=511, step=1, interactive=True)
+
+
+                     erode_amount = gr.Slider(
+                         label="Mask Erode",
+                         minimum=0,
+                         maximum=1,
+                         value=MASK_ERODE_AMOUNT,
+                         step=0.05,
+                         interactive=True,
+                     )
+
+                     blur_amount = gr.Slider(
+                         label="Mask Blur",
+                         minimum=0,
+                         maximum=1,
+                         value=MASK_BLUR_AMOUNT,
+                         step=0.05,
+                         interactive=True,
+                     )
+
+                     enable_laplacian_blend = gr.Checkbox(
+                         label="Laplacian Blending",
+                         value=True,
+                         interactive=True,
+                     )
+
+
+                 source_image_input = gr.Image(
+                     label="Source face", type="filepath", interactive=True
+                 )
+
+                 with gr.Box(visible=False) as specific_face:
+                     for i in range(NUM_OF_SRC_SPECIFIC):
+                         idx = i + 1
+                         code = "\n"
+                         code += f"with gr.Tab(label='({idx})'):"
+                         code += "\n\twith gr.Row():"
+                         code += f"\n\t\tsrc{idx} = gr.Image(interactive=True, type='numpy', label='Source Face {idx}')"
+                         code += f"\n\t\ttrg{idx} = gr.Image(interactive=True, type='numpy', label='Specific Face {idx}')"
+                         exec(code)
+
+                     distance_slider = gr.Slider(
+                         minimum=0,
+                         maximum=2,
+                         value=0.6,
+                         interactive=True,
+                         label="Distance",
+                         info="Lower distance is more similar and higher distance is less similar to the target face.",
+                     )
+
+                 with gr.Group():
+                     input_type = gr.Radio(
+                         ["Image", "Video"],
+                         label="Target Type",
+                         value="Image",
+                     )
+
+                     with gr.Box(visible=True) as input_image_group:
+                         image_input = gr.Image(
+                             label="Target Image", interactive=True, type="filepath"
+                         )
+
+                     with gr.Box(visible=False) as input_video_group:
+                         vid_widget = gr.Video if USE_COLAB else gr.Text
+                         video_input = gr.Video(
+                             label="Target Video", interactive=True
+                         )
+                         with gr.Accordion("✂️ Trim video", open=False):
+                             with gr.Column():
+                                 with gr.Row():
+                                     set_slider_range_btn = gr.Button(
+                                         "Set frame range", interactive=True
+                                     )
+                                     show_trim_preview_btn = gr.Checkbox(
+                                         label="Show frame when slider change",
+                                         value=True,
+                                         interactive=True,
+                                     )
+
+                                 video_fps = gr.Number(
+                                     value=30,
+                                     interactive=False,
+                                     label="Fps",
+                                     visible=False,
+                                 )
+                                 start_frame = gr.Slider(
+                                     minimum=0,
+                                     maximum=1,
+                                     value=0,
+                                     step=1,
+                                     interactive=True,
+                                     label="Start Frame",
+                                     info="",
+                                 )
+                                 end_frame = gr.Slider(
+                                     minimum=0,
+                                     maximum=1,
+                                     value=1,
+                                     step=1,
+                                     interactive=True,
+                                     label="End Frame",
+                                     info="",
+                                 )
+                             trim_and_reload_btn = gr.Button(
+                                 "Trim and Reload", interactive=True
+                             )
+
+                     with gr.Box(visible=False) as input_directory_group:
+                         direc_input = gr.Text(label="Path", interactive=True)
+
+             with gr.Column(scale=0.6):
+                 info = gr.Markdown(value="...")
+
+                 with gr.Row():
+                     swap_button = gr.Button("✨ Swap", variant="primary")
+                     cancel_button = gr.Button("⛔ Cancel")
+
+                 preview_image = gr.Image(label="Output", interactive=False)
+                 preview_video = gr.Video(
+                     label="Output", interactive=False, visible=False
+                 )
+
+                 with gr.Row():
+                     output_directory_button = gr.Button(
+                         "📂", interactive=False, visible=False
+                     )
+                     output_video_button = gr.Button(
+                         "🎬", interactive=False, visible=False
+                     )
+
+                 with gr.Box():
+                     with gr.Row():
+                         gr.Markdown(
+                             "### [🤝 Sponsor](https://github.com/sponsors/harisreedhar)"
+                         )
+                         gr.Markdown(
+                             "### [👨‍💻 Source code](https://github.com/harisreedhar/Swap-Mukham)"
+                         )
+                         gr.Markdown(
+                             "### [⚠️ Disclaimer](https://github.com/harisreedhar/Swap-Mukham#disclaimer)"
+                         )
+                         gr.Markdown(
+                             "### [🌐 Run in Colab](https://colab.research.google.com/github/harisreedhar/Swap-Mukham/blob/main/swap_mukham_colab.ipynb)"
+                         )
+                         gr.Markdown(
+                             "### [🤗 Acknowledgements](https://github.com/harisreedhar/Swap-Mukham#acknowledgements)"
+                         )
+
+     ## ------------------------------ GRADIO EVENTS ------------------------------
+
+     set_slider_range_event = set_slider_range_btn.click(
+         video_changed,
+         inputs=[video_input],
+         outputs=[start_frame, end_frame, video_fps],
+     )
+
+     trim_and_reload_event = trim_and_reload_btn.click(
+         fn=trim_and_reload,
+         inputs=[video_input, output_directory, output_name, start_frame, end_frame],
+         outputs=[video_input, info],
+     )
+
+     start_frame_event = start_frame.release(
+         fn=slider_changed,
+         inputs=[show_trim_preview_btn, video_input, start_frame],
+         outputs=[preview_image, preview_video],
+         show_progress=True,
+     )
+
+     end_frame_event = end_frame.release(
+         fn=slider_changed,
+         inputs=[show_trim_preview_btn, video_input, end_frame],
+         outputs=[preview_image, preview_video],
+         show_progress=True,
+     )
+
+     input_type.change(
+         update_radio,
+         inputs=[input_type],
+         outputs=[input_image_group, input_video_group, input_directory_group],
+     )
+     swap_option.change(
+         swap_option_changed,
+         inputs=[swap_option],
+         outputs=[age, specific_face, source_image_input],
+     )
+
+     apply_detection_settings.click(
+         analyse_settings_changed,
+         inputs=[detect_condition_dropdown, detection_size, detection_threshold],
+         outputs=[info],
+     )
+
+     src_specific_inputs = []
+     gen_variable_txt = ",".join(
+         [f"src{i+1}" for i in range(NUM_OF_SRC_SPECIFIC)]
+         + [f"trg{i+1}" for i in range(NUM_OF_SRC_SPECIFIC)]
+     )
+     exec(f"src_specific_inputs = ({gen_variable_txt})")
+     swap_inputs = [
+         input_type,
+         image_input,
+         video_input,
+         direc_input,
+         source_image_input,
+         output_directory,
+         output_name,
+         keep_output_sequence,
+         swap_option,
+         age,
+         distance_slider,
+         face_enhancer_name,
+         enable_face_parser_mask,
+         mask_include,
+         mask_soft_kernel,
+         mask_soft_iterations,
+         blur_amount,
+         erode_amount,
+         face_scale,
+         enable_laplacian_blend,
+         crop_top,
+         crop_bott,
+         crop_left,
+         crop_right,
+         *src_specific_inputs,
+     ]
+
+     swap_outputs = [
+         info,
+         preview_image,
+         output_directory_button,
+         output_video_button,
+         preview_video,
+     ]
+
+     swap_event = swap_button.click(
+         fn=process, inputs=swap_inputs, outputs=swap_outputs, show_progress=True
+     )
+
+     cancel_button.click(
+         fn=stop_running,
+         inputs=None,
+         outputs=[info],
+         cancels=[
+             swap_event,
+             trim_and_reload_event,
+             set_slider_range_event,
+             start_frame_event,
+             end_frame_event,
+         ],
+         show_progress=True,
+     )
+     output_directory_button.click(
+         lambda: open_directory(path=WORKSPACE), inputs=None, outputs=None
+     )
+     output_video_button.click(
+         lambda: open_directory(path=OUTPUT_FILE), inputs=None, outputs=None
+     )
+
+ if __name__ == "__main__":
+     if USE_COLAB:
+         print("Running in colab mode")
+
+     interface.queue(concurrency_count=2, max_size=20).launch(share=USE_COLAB)
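
For local use (not part of the committed files), the flags defined in the argparse block at the top of app.py suggest a launch command along these lines; a sketch wrapped in `subprocess` for illustration, with hypothetical values for `--batch_size` and `--out_dir`:

```python
import subprocess

# A sketch: flags mirror the argparse definitions at the top of app.py.
subprocess.run([
    "python", "app.py",
    "--cuda",                  # opt into CUDAExecutionProvider when available
    "--batch_size", "32",      # GPU batch size (parser default is 32)
    "--out_dir", "./outputs",  # hypothetical; default is os.getcwd()
])
```

Without `--cuda` the app falls back to CPUExecutionProvider, and `--colab` additionally makes Gradio launch with `share=True`.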
assets/images/logo.png ADDED
assets/pretrained_models/79999_iter.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:468e13ca13a9b43cc0881a9f99083a430e9c0a38abd935431d1c28ee94b26567
+ size 53289463
assets/pretrained_models/GFPGANv1.4.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:e2cd4703ab14f4d01fd1383a8a8b266f9a5833dacee8e6a79d3bf21a1b6be5ad
+ size 348632874
assets/pretrained_models/RealESRGAN_x2.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:c830d067d54fc767b9543a8432f36d91bc2de313584e8bbfe4ac26a47339e899
+ size 67061725
assets/pretrained_models/RealESRGAN_x4.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:aa00f09ad753d88576b21ed977e97d634976377031b178acc3b5b238df463400
+ size 67040989
assets/pretrained_models/RealESRGAN_x8.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:8b72fb469d12f05a4770813d2603eb1b550f40df6fb8b37d6c7bc2db3d2bff5e
+ size 67189359
assets/pretrained_models/codeformer.onnx ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:91e7e881c5001fea4a535e8f96eaeaa672d30c963a678a3e27f0429a6620f57a
+ size 376821950
assets/pretrained_models/inswapper_128.onnx ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:e4a3f08c753cb72d04e10aa0f7dbe3deebbf39567d4ead6dce08e98aa49e16af
+ size 554253681
assets/pretrained_models/open-nsfw.onnx ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:864bb37bf8863564b87eb330ab8c785a79a773f4e7c43cb96db52ed8611305fa
+ size 23590724
assets/pretrained_models/readme.md ADDED
@@ -0,0 +1,4 @@
+ ## Download these models here
+ - [inswapper_128.onnx](https://huggingface.co/deepinsight/inswapper/resolve/main/inswapper_128.onnx)
+ - [GFPGANv1.4.pth](https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.4.pth)
+ - [79999_iter.pth](https://drive.google.com/open?id=154JgKpzCPW82qINcVieuPH3fZ2e0P812)
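
A minimal sketch (not part of this commit) for fetching the two directly linked files into the `assets/pretrained_models/` paths that app.py expects; the Google Drive link for `79999_iter.pth` generally needs a manual download or a tool such as gdown:

```python
import os
import urllib.request

# URLs copied from the list above; destination matches the paths used in app.py.
MODELS = {
    "inswapper_128.onnx": "https://huggingface.co/deepinsight/inswapper/resolve/main/inswapper_128.onnx",
    "GFPGANv1.4.pth": "https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.4.pth",
}

dest_dir = "./assets/pretrained_models"
os.makedirs(dest_dir, exist_ok=True)
for name, url in MODELS.items():
    dest = os.path.join(dest_dir, name)
    if not os.path.exists(dest):  # skip files that are already present
        print(f"Downloading {name} ...")
        urllib.request.urlretrieve(url, dest)
```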
face_analyser.py ADDED
@@ -0,0 +1,194 @@
+ import os
+ import cv2
+ import numpy as np
+ from tqdm import tqdm
+ from utils import scale_bbox_from_center
+
+ detect_conditions = [
+     "best detection",
+     "left most",
+     "right most",
+     "top most",
+     "bottom most",
+     "middle",
+     "biggest",
+     "smallest",
+ ]
+
+ swap_options_list = [
+     "All Face",
+     "Specific Face",
+     "Age less than",
+     "Age greater than",
+     "All Male",
+     "All Female",
+     "Left Most",
+     "Right Most",
+     "Top Most",
+     "Bottom Most",
+     "Middle",
+     "Biggest",
+     "Smallest",
+ ]
+
+ def get_single_face(faces, method="best detection"):
+     total_faces = len(faces)
+     if total_faces == 1:
+         return faces[0]
+
+     print(f"{total_faces} faces detected. Using {method} face.")
+     if method == "best detection":
+         return sorted(faces, key=lambda face: face["det_score"])[-1]
+     elif method == "left most":
+         return sorted(faces, key=lambda face: face["bbox"][0])[0]
+     elif method == "right most":
+         return sorted(faces, key=lambda face: face["bbox"][0])[-1]
+     elif method == "top most":
+         return sorted(faces, key=lambda face: face["bbox"][1])[0]
+     elif method == "bottom most":
+         return sorted(faces, key=lambda face: face["bbox"][1])[-1]
+     elif method == "middle":
+         return sorted(faces, key=lambda face: (
+             (face["bbox"][0] + face["bbox"][2]) / 2 - 0.5) ** 2 +
+             ((face["bbox"][1] + face["bbox"][3]) / 2 - 0.5) ** 2)[len(faces) // 2]
+     elif method == "biggest":
+         return sorted(faces, key=lambda face: (face["bbox"][2] - face["bbox"][0]) * (face["bbox"][3] - face["bbox"][1]))[-1]
+     elif method == "smallest":
+         return sorted(faces, key=lambda face: (face["bbox"][2] - face["bbox"][0]) * (face["bbox"][3] - face["bbox"][1]))[0]
+
+
+ def analyse_face(image, model, return_single_face=True, detect_condition="best detection", scale=1.0):
+     faces = model.get(image)
+     if scale != 1: # landmark-scale
+         for i, face in enumerate(faces):
+             landmark = face['kps']
+             center = np.mean(landmark, axis=0)
+             landmark = center + (landmark - center) * scale
+             faces[i]['kps'] = landmark
+
+     if not return_single_face:
+         return faces
+
+     return get_single_face(faces, method=detect_condition)
+
+
+ def cosine_distance(a, b):
+     a /= np.linalg.norm(a)
+     b /= np.linalg.norm(b)
+     return 1 - np.dot(a, b)
+
+
+ def get_analysed_data(face_analyser, image_sequence, source_data, swap_condition="All face", detect_condition="left most", scale=1.0):
+     if swap_condition != "Specific Face":
+         source_path, age = source_data
+         source_image = cv2.imread(source_path)
+         analysed_source = analyse_face(source_image, face_analyser, return_single_face=True, detect_condition=detect_condition, scale=scale)
+     else:
+         analysed_source_specifics = []
+         source_specifics, threshold = source_data
+         for source, specific in zip(*source_specifics):
+             if source is None or specific is None:
+                 continue
+             analysed_source = analyse_face(source, face_analyser, return_single_face=True, detect_condition=detect_condition, scale=scale)
+             analysed_specific = analyse_face(specific, face_analyser, return_single_face=True, detect_condition=detect_condition, scale=scale)
+             analysed_source_specifics.append([analysed_source, analysed_specific])
+
+     analysed_target_list = []
+     analysed_source_list = []
+     whole_frame_eql_list = []
+     num_faces_per_frame = []
+
+     total_frames = len(image_sequence)
+     curr_idx = 0
+     for curr_idx, frame_path in tqdm(enumerate(image_sequence), total=total_frames, desc="Analysing face data"):
+         frame = cv2.imread(frame_path)
+         analysed_faces = analyse_face(frame, face_analyser, return_single_face=False, detect_condition=detect_condition, scale=scale)
+
+         n_faces = 0
+         for analysed_face in analysed_faces:
+             if swap_condition == "All Face":
+                 analysed_target_list.append(analysed_face)
+                 analysed_source_list.append(analysed_source)
+                 whole_frame_eql_list.append(frame_path)
+                 n_faces += 1
+             elif swap_condition == "Age less than" and analysed_face["age"] < age:
+                 analysed_target_list.append(analysed_face)
+                 analysed_source_list.append(analysed_source)
+                 whole_frame_eql_list.append(frame_path)
+                 n_faces += 1
+             elif swap_condition == "Age greater than" and analysed_face["age"] > age:
+                 analysed_target_list.append(analysed_face)
+                 analysed_source_list.append(analysed_source)
+                 whole_frame_eql_list.append(frame_path)
+                 n_faces += 1
+             elif swap_condition == "All Male" and analysed_face["gender"] == 1:
+                 analysed_target_list.append(analysed_face)
+                 analysed_source_list.append(analysed_source)
+                 whole_frame_eql_list.append(frame_path)
+                 n_faces += 1
+             elif swap_condition == "All Female" and analysed_face["gender"] == 0:
+                 analysed_target_list.append(analysed_face)
+                 analysed_source_list.append(analysed_source)
+                 whole_frame_eql_list.append(frame_path)
+                 n_faces += 1
+             elif swap_condition == "Specific Face":
+                 for analysed_source, analysed_specific in analysed_source_specifics:
+                     distance = cosine_distance(analysed_specific["embedding"], analysed_face["embedding"])
+                     if distance < threshold:
+                         analysed_target_list.append(analysed_face)
+                         analysed_source_list.append(analysed_source)
+                         whole_frame_eql_list.append(frame_path)
+                         n_faces += 1
+
+         if swap_condition == "Left Most":
+             analysed_face = get_single_face(analysed_faces, method="left most")
+             analysed_target_list.append(analysed_face)
+             analysed_source_list.append(analysed_source)
+             whole_frame_eql_list.append(frame_path)
+             n_faces += 1
+
+         elif swap_condition == "Right Most":
+             analysed_face = get_single_face(analysed_faces, method="right most")
+             analysed_target_list.append(analysed_face)
+             analysed_source_list.append(analysed_source)
+             whole_frame_eql_list.append(frame_path)
+             n_faces += 1
+
+         elif swap_condition == "Top Most":
+             analysed_face = get_single_face(analysed_faces, method="top most")
+             analysed_target_list.append(analysed_face)
+             analysed_source_list.append(analysed_source)
+             whole_frame_eql_list.append(frame_path)
+             n_faces += 1
+
+         elif swap_condition == "Bottom Most":
+             analysed_face = get_single_face(analysed_faces, method="bottom most")
+             analysed_target_list.append(analysed_face)
+             analysed_source_list.append(analysed_source)
+             whole_frame_eql_list.append(frame_path)
+             n_faces += 1
+
+         elif swap_condition == "Middle":
+             analysed_face = get_single_face(analysed_faces, method="middle")
+             analysed_target_list.append(analysed_face)
+             analysed_source_list.append(analysed_source)
+             whole_frame_eql_list.append(frame_path)
+             n_faces += 1
+
+         elif swap_condition == "Biggest":
+             analysed_face = get_single_face(analysed_faces, method="biggest")
+             analysed_target_list.append(analysed_face)
+             analysed_source_list.append(analysed_source)
+             whole_frame_eql_list.append(frame_path)
+             n_faces += 1
+
+         elif swap_condition == "Smallest":
+             analysed_face = get_single_face(analysed_faces, method="smallest")
+             analysed_target_list.append(analysed_face)
+             analysed_source_list.append(analysed_source)
+             whole_frame_eql_list.append(frame_path)
+             n_faces += 1
+
+         num_faces_per_frame.append(n_faces)
+
+     return analysed_target_list, analysed_source_list, whole_frame_eql_list, num_faces_per_frame
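
For intuition about the matching logic above: `cosine_distance` returns 1 minus the cosine similarity of two embeddings (which it normalises in place), so identical faces score near 0 and unrelated ones near 1; "Specific Face" mode keeps a face when this distance falls below the UI threshold (the Distance slider, default 0.6). A tiny check with made-up vectors:

```python
import numpy as np

# Made-up embeddings purely for illustration; real ones come from insightface.
a = np.array([1.0, 0.0, 0.0])
b = np.array([0.0, 1.0, 0.0])

print(cosine_distance(a.copy(), a.copy()))  # 0.0 -> treated as the same identity
print(cosine_distance(a.copy(), b.copy()))  # 1.0 -> no match
# Note: cosine_distance normalises its arguments in place, hence the copies.
```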
face_enhancer.py ADDED
@@ -0,0 +1,72 @@
+ import os
+ import cv2
+ import torch
+ import gfpgan
+ from PIL import Image
+ from upscaler.RealESRGAN import RealESRGAN
+ from upscaler.codeformer import CodeFormerEnhancer
+
+ def gfpgan_runner(img, model):
+     _, imgs, _ = model.enhance(img, paste_back=True, has_aligned=True)
+     return imgs[0]
+
+
+ def realesrgan_runner(img, model):
+     img = model.predict(img)
+     return img
+
+
+ def codeformer_runner(img, model):
+     img = model.enhance(img)
+     return img
+
+
+ supported_enhancers = {
+     "CodeFormer": ("./assets/pretrained_models/codeformer.onnx", codeformer_runner),
+     "GFPGAN": ("./assets/pretrained_models/GFPGANv1.4.pth", gfpgan_runner),
+     "REAL-ESRGAN 2x": ("./assets/pretrained_models/RealESRGAN_x2.pth", realesrgan_runner),
+     "REAL-ESRGAN 4x": ("./assets/pretrained_models/RealESRGAN_x4.pth", realesrgan_runner),
+     "REAL-ESRGAN 8x": ("./assets/pretrained_models/RealESRGAN_x8.pth", realesrgan_runner)
+ }
+
+ cv2_interpolations = ["LANCZOS4", "CUBIC", "NEAREST"]
+
+ def get_available_enhancer_names():
+     available = []
+     for name, data in supported_enhancers.items():
+         path = os.path.join(os.path.abspath(os.path.dirname(__file__)), data[0])
+         if os.path.exists(path):
+             available.append(name)
+     return available
+
+
+ def load_face_enhancer_model(name='GFPGAN', device="cpu"):
+     assert name in get_available_enhancer_names() + cv2_interpolations, f"Face enhancer {name} unavailable."
+     if name in supported_enhancers.keys():
+         model_path, model_runner = supported_enhancers.get(name)
+         model_path = os.path.join(os.path.abspath(os.path.dirname(__file__)), model_path)
+     if name == 'CodeFormer':
+         model = CodeFormerEnhancer(model_path=model_path, device=device)
+     elif name == 'GFPGAN':
+         model = gfpgan.GFPGANer(model_path=model_path, upscale=1, device=device)
+     elif name == 'REAL-ESRGAN 2x':
+         model = RealESRGAN(device, scale=2)
+         model.load_weights(model_path, download=False)
+     elif name == 'REAL-ESRGAN 4x':
+         model = RealESRGAN(device, scale=4)
+         model.load_weights(model_path, download=False)
+     elif name == 'REAL-ESRGAN 8x':
+         model = RealESRGAN(device, scale=8)
+         model.load_weights(model_path, download=False)
+     elif name == 'LANCZOS4':
+         model = None
+         model_runner = lambda img, _: cv2.resize(img, (512,512), interpolation=cv2.INTER_LANCZOS4)
+     elif name == 'CUBIC':
+         model = None
+         model_runner = lambda img, _: cv2.resize(img, (512,512), interpolation=cv2.INTER_CUBIC)
+     elif name == 'NEAREST':
+         model = None
+         model_runner = lambda img, _: cv2.resize(img, (512,512), interpolation=cv2.INTER_NEAREST)
+     else:
+         model = None
+     return (model, model_runner)
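
As wired up in app.py, `load_face_enhancer_model` returns a `(model, runner)` pair and the runner is invoked as `runner(img, model)`. A minimal sketch, assuming the GFPGAN weights are present under `assets/pretrained_models/` and `some_face.png` is a hypothetical aligned face crop:

```python
import cv2
from face_enhancer import load_face_enhancer_model

model, runner = load_face_enhancer_model(name="GFPGAN", device="cpu")
face = cv2.imread("some_face.png")           # hypothetical aligned face crop
enhanced = runner(face, model)               # the GFPGAN runner expects an aligned face
enhanced = cv2.resize(enhanced, (512, 512))  # app.py resizes results back to 512x512
cv2.imwrite("enhanced.png", enhanced)
```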
face_parsing/__init__.py ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ from .swap import init_parser, swap_regions, mask_regions, mask_regions_to_list
2
+ from .model import BiSeNet
3
+ from .parse_mask import init_parsing_model, get_parsed_mask, SoftErosion
face_parsing/model.py ADDED
@@ -0,0 +1,283 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/python
2
+ # -*- encoding: utf-8 -*-
3
+
4
+
5
+ import torch
6
+ import torch.nn as nn
7
+ import torch.nn.functional as F
8
+ import torchvision
9
+
10
+ from .resnet import Resnet18
11
+ # from modules.bn import InPlaceABNSync as BatchNorm2d
12
+
13
+
14
+ class ConvBNReLU(nn.Module):
15
+ def __init__(self, in_chan, out_chan, ks=3, stride=1, padding=1, *args, **kwargs):
16
+ super(ConvBNReLU, self).__init__()
17
+ self.conv = nn.Conv2d(in_chan,
18
+ out_chan,
19
+ kernel_size = ks,
20
+ stride = stride,
21
+ padding = padding,
22
+ bias = False)
23
+ self.bn = nn.BatchNorm2d(out_chan)
24
+ self.init_weight()
25
+
26
+ def forward(self, x):
27
+ x = self.conv(x)
28
+ x = F.relu(self.bn(x))
29
+ return x
30
+
31
+ def init_weight(self):
32
+ for ly in self.children():
33
+ if isinstance(ly, nn.Conv2d):
34
+ nn.init.kaiming_normal_(ly.weight, a=1)
35
+ if not ly.bias is None: nn.init.constant_(ly.bias, 0)
36
+
37
+ class BiSeNetOutput(nn.Module):
38
+ def __init__(self, in_chan, mid_chan, n_classes, *args, **kwargs):
39
+ super(BiSeNetOutput, self).__init__()
40
+ self.conv = ConvBNReLU(in_chan, mid_chan, ks=3, stride=1, padding=1)
41
+ self.conv_out = nn.Conv2d(mid_chan, n_classes, kernel_size=1, bias=False)
42
+ self.init_weight()
43
+
44
+ def forward(self, x):
45
+ x = self.conv(x)
46
+ x = self.conv_out(x)
47
+ return x
48
+
49
+ def init_weight(self):
50
+ for ly in self.children():
51
+ if isinstance(ly, nn.Conv2d):
52
+ nn.init.kaiming_normal_(ly.weight, a=1)
53
+ if not ly.bias is None: nn.init.constant_(ly.bias, 0)
54
+
55
+ def get_params(self):
56
+ wd_params, nowd_params = [], []
57
+ for name, module in self.named_modules():
58
+ if isinstance(module, nn.Linear) or isinstance(module, nn.Conv2d):
59
+ wd_params.append(module.weight)
60
+ if not module.bias is None:
61
+ nowd_params.append(module.bias)
62
+ elif isinstance(module, nn.BatchNorm2d):
63
+ nowd_params += list(module.parameters())
64
+ return wd_params, nowd_params
65
+
66
+
67
+ class AttentionRefinementModule(nn.Module):
68
+ def __init__(self, in_chan, out_chan, *args, **kwargs):
69
+ super(AttentionRefinementModule, self).__init__()
70
+ self.conv = ConvBNReLU(in_chan, out_chan, ks=3, stride=1, padding=1)
71
+ self.conv_atten = nn.Conv2d(out_chan, out_chan, kernel_size= 1, bias=False)
72
+ self.bn_atten = nn.BatchNorm2d(out_chan)
73
+ self.sigmoid_atten = nn.Sigmoid()
74
+ self.init_weight()
75
+
76
+ def forward(self, x):
77
+ feat = self.conv(x)
78
+ atten = F.avg_pool2d(feat, feat.size()[2:])
79
+ atten = self.conv_atten(atten)
80
+ atten = self.bn_atten(atten)
81
+ atten = self.sigmoid_atten(atten)
82
+ out = torch.mul(feat, atten)
83
+ return out
84
+
85
+ def init_weight(self):
86
+ for ly in self.children():
87
+ if isinstance(ly, nn.Conv2d):
88
+ nn.init.kaiming_normal_(ly.weight, a=1)
89
+ if not ly.bias is None: nn.init.constant_(ly.bias, 0)
90
+
91
+
92
+ class ContextPath(nn.Module):
93
+ def __init__(self, *args, **kwargs):
94
+ super(ContextPath, self).__init__()
95
+ self.resnet = Resnet18()
96
+ self.arm16 = AttentionRefinementModule(256, 128)
97
+ self.arm32 = AttentionRefinementModule(512, 128)
98
+ self.conv_head32 = ConvBNReLU(128, 128, ks=3, stride=1, padding=1)
99
+ self.conv_head16 = ConvBNReLU(128, 128, ks=3, stride=1, padding=1)
100
+ self.conv_avg = ConvBNReLU(512, 128, ks=1, stride=1, padding=0)
101
+
102
+ self.init_weight()
103
+
104
+ def forward(self, x):
105
+ H0, W0 = x.size()[2:]
106
+ feat8, feat16, feat32 = self.resnet(x)
107
+ H8, W8 = feat8.size()[2:]
108
+ H16, W16 = feat16.size()[2:]
109
+ H32, W32 = feat32.size()[2:]
110
+
111
+ avg = F.avg_pool2d(feat32, feat32.size()[2:])
112
+ avg = self.conv_avg(avg)
113
+ avg_up = F.interpolate(avg, (H32, W32), mode='nearest')
114
+
115
+ feat32_arm = self.arm32(feat32)
116
+ feat32_sum = feat32_arm + avg_up
117
+ feat32_up = F.interpolate(feat32_sum, (H16, W16), mode='nearest')
118
+ feat32_up = self.conv_head32(feat32_up)
119
+
120
+ feat16_arm = self.arm16(feat16)
121
+ feat16_sum = feat16_arm + feat32_up
122
+ feat16_up = F.interpolate(feat16_sum, (H8, W8), mode='nearest')
123
+ feat16_up = self.conv_head16(feat16_up)
124
+
125
+ return feat8, feat16_up, feat32_up # x8, x8, x16
126
+
127
+ def init_weight(self):
128
+ for ly in self.children():
129
+ if isinstance(ly, nn.Conv2d):
130
+ nn.init.kaiming_normal_(ly.weight, a=1)
131
+ if not ly.bias is None: nn.init.constant_(ly.bias, 0)
132
+
133
+ def get_params(self):
134
+ wd_params, nowd_params = [], []
135
+ for name, module in self.named_modules():
136
+ if isinstance(module, (nn.Linear, nn.Conv2d)):
137
+ wd_params.append(module.weight)
138
+ if not module.bias is None:
139
+ nowd_params.append(module.bias)
140
+ elif isinstance(module, nn.BatchNorm2d):
141
+ nowd_params += list(module.parameters())
142
+ return wd_params, nowd_params
143
+
144
+
145
+ ### This is not used, since I replace this with the resnet feature with the same size
146
+ class SpatialPath(nn.Module):
147
+ def __init__(self, *args, **kwargs):
148
+ super(SpatialPath, self).__init__()
149
+ self.conv1 = ConvBNReLU(3, 64, ks=7, stride=2, padding=3)
150
+ self.conv2 = ConvBNReLU(64, 64, ks=3, stride=2, padding=1)
151
+ self.conv3 = ConvBNReLU(64, 64, ks=3, stride=2, padding=1)
152
+        self.conv_out = ConvBNReLU(64, 128, ks=1, stride=1, padding=0)
+        self.init_weight()
+
+    def forward(self, x):
+        feat = self.conv1(x)
+        feat = self.conv2(feat)
+        feat = self.conv3(feat)
+        feat = self.conv_out(feat)
+        return feat
+
+    def init_weight(self):
+        for ly in self.children():
+            if isinstance(ly, nn.Conv2d):
+                nn.init.kaiming_normal_(ly.weight, a=1)
+                if ly.bias is not None:
+                    nn.init.constant_(ly.bias, 0)
+
+    def get_params(self):
+        wd_params, nowd_params = [], []
+        for name, module in self.named_modules():
+            if isinstance(module, (nn.Linear, nn.Conv2d)):
+                wd_params.append(module.weight)
+                if module.bias is not None:
+                    nowd_params.append(module.bias)
+            elif isinstance(module, nn.BatchNorm2d):
+                nowd_params += list(module.parameters())
+        return wd_params, nowd_params
+
+
+class FeatureFusionModule(nn.Module):
+    def __init__(self, in_chan, out_chan, *args, **kwargs):
+        super(FeatureFusionModule, self).__init__()
+        self.convblk = ConvBNReLU(in_chan, out_chan, ks=1, stride=1, padding=0)
+        self.conv1 = nn.Conv2d(out_chan,
+                               out_chan // 4,
+                               kernel_size=1,
+                               stride=1,
+                               padding=0,
+                               bias=False)
+        self.conv2 = nn.Conv2d(out_chan // 4,
+                               out_chan,
+                               kernel_size=1,
+                               stride=1,
+                               padding=0,
+                               bias=False)
+        self.relu = nn.ReLU(inplace=True)
+        self.sigmoid = nn.Sigmoid()
+        self.init_weight()
+
+    def forward(self, fsp, fcp):
+        fcat = torch.cat([fsp, fcp], dim=1)
+        feat = self.convblk(fcat)
+        atten = F.avg_pool2d(feat, feat.size()[2:])
+        atten = self.conv1(atten)
+        atten = self.relu(atten)
+        atten = self.conv2(atten)
+        atten = self.sigmoid(atten)
+        feat_atten = torch.mul(feat, atten)
+        feat_out = feat_atten + feat
+        return feat_out
+
+    def init_weight(self):
+        for ly in self.children():
+            if isinstance(ly, nn.Conv2d):
+                nn.init.kaiming_normal_(ly.weight, a=1)
+                if ly.bias is not None:
+                    nn.init.constant_(ly.bias, 0)
+
+    def get_params(self):
+        wd_params, nowd_params = [], []
+        for name, module in self.named_modules():
+            if isinstance(module, (nn.Linear, nn.Conv2d)):
+                wd_params.append(module.weight)
+                if module.bias is not None:
+                    nowd_params.append(module.bias)
+            elif isinstance(module, nn.BatchNorm2d):
+                nowd_params += list(module.parameters())
+        return wd_params, nowd_params
+
+
+class BiSeNet(nn.Module):
+    def __init__(self, n_classes, *args, **kwargs):
+        super(BiSeNet, self).__init__()
+        self.cp = ContextPath()
+        # the original BiSeNet spatial path (self.sp) is dropped in this variant
+        self.ffm = FeatureFusionModule(256, 256)
+        self.conv_out = BiSeNetOutput(256, 256, n_classes)
+        self.conv_out16 = BiSeNetOutput(128, 64, n_classes)
+        self.conv_out32 = BiSeNetOutput(128, 64, n_classes)
+        self.init_weight()
+
+    def forward(self, x):
+        H, W = x.size()[2:]
+        feat_res8, feat_cp8, feat_cp16 = self.cp(x)  # the context path also returns the res3b1 feature
+        feat_sp = feat_res8  # the res3b1 feature stands in for the spatial-path feature
+        feat_fuse = self.ffm(feat_sp, feat_cp8)
+
+        feat_out = self.conv_out(feat_fuse)
+        feat_out16 = self.conv_out16(feat_cp8)
+        feat_out32 = self.conv_out32(feat_cp16)
+
+        feat_out = F.interpolate(feat_out, (H, W), mode='bilinear', align_corners=True)
+        feat_out16 = F.interpolate(feat_out16, (H, W), mode='bilinear', align_corners=True)
+        feat_out32 = F.interpolate(feat_out32, (H, W), mode='bilinear', align_corners=True)
+        return feat_out, feat_out16, feat_out32
+
+    def init_weight(self):
+        for ly in self.children():
+            if isinstance(ly, nn.Conv2d):
+                nn.init.kaiming_normal_(ly.weight, a=1)
+                if ly.bias is not None:
+                    nn.init.constant_(ly.bias, 0)
+
+    def get_params(self):
+        wd_params, nowd_params, lr_mul_wd_params, lr_mul_nowd_params = [], [], [], []
+        for name, child in self.named_children():
+            child_wd_params, child_nowd_params = child.get_params()
+            if isinstance(child, (FeatureFusionModule, BiSeNetOutput)):
+                lr_mul_wd_params += child_wd_params
+                lr_mul_nowd_params += child_nowd_params
+            else:
+                wd_params += child_wd_params
+                nowd_params += child_nowd_params
+        return wd_params, nowd_params, lr_mul_wd_params, lr_mul_nowd_params
+
+
+if __name__ == "__main__":
+    net = BiSeNet(19)
+    net.cuda()
+    net.eval()
+    in_ten = torch.randn(16, 3, 640, 480).cuda()
+    out, out16, out32 = net(in_ten)
+    print(out.shape)
+
+    net.get_params()
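For orientation, a minimal usage sketch of the parsing network above; the checkpoint path and the ImageNet-style preprocessing mirror the rest of this repo but are assumptions here:

    # Sketch: run BiSeNet face parsing on one BGR image (paths are hypothetical).
    import cv2
    import torch
    import numpy as np
    from face_parsing.model import BiSeNet

    net = BiSeNet(19)
    net.load_state_dict(torch.load("assets/pretrained_models/79999_iter.pth", map_location="cpu"))
    net.eval()

    img = cv2.imread("face.jpg")                            # hypothetical input
    img = cv2.resize(img, (512, 512))[:, :, ::-1] / 255.0   # BGR -> RGB, [0, 1]
    mean = np.array([0.485, 0.456, 0.406])
    std = np.array([0.229, 0.224, 0.225])
    x = torch.from_numpy(((img - mean) / std).transpose(2, 0, 1)).float().unsqueeze(0)

    with torch.no_grad():
        label_map = net(x)[0].argmax(dim=1)[0].numpy()  # per-pixel class ids 0..18
    print(label_map.shape)  # (512, 512)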
face_parsing/parse_mask.py ADDED
@@ -0,0 +1,107 @@
+import cv2
+import torch
+import torchvision
+import numpy as np
+import torch.nn as nn
+from PIL import Image
+from tqdm import tqdm
+import torch.nn.functional as F
+import torchvision.transforms as transforms
+
+from .model import BiSeNet
+
+
+class SoftErosion(nn.Module):
+    def __init__(self, kernel_size=15, threshold=0.6, iterations=1):
+        super(SoftErosion, self).__init__()
+        r = kernel_size // 2
+        self.padding = r
+        self.iterations = iterations
+        self.threshold = threshold
+
+        # Create a radially decaying kernel
+        y_indices, x_indices = torch.meshgrid(torch.arange(0., kernel_size), torch.arange(0., kernel_size))
+        dist = torch.sqrt((x_indices - r) ** 2 + (y_indices - r) ** 2)
+        kernel = dist.max() - dist
+        kernel /= kernel.sum()
+        kernel = kernel.view(1, 1, *kernel.shape)
+        self.register_buffer('weight', kernel)
+
+    def forward(self, x):
+        batch_size = x.size(0)
+        output = []
+
+        for i in tqdm(range(batch_size), desc="Soft-Erosion", leave=False):
+            input_tensor = x[i:i + 1]                 # take one mask from the batch
+            input_tensor = input_tensor.float()       # convert to float tensor
+            input_tensor = input_tensor.unsqueeze(1)  # add a channel dimension
+
+            for _ in range(self.iterations - 1):
+                input_tensor = torch.min(input_tensor, F.conv2d(input_tensor, weight=self.weight,
+                                                                groups=input_tensor.shape[1],
+                                                                padding=self.padding))
+            input_tensor = F.conv2d(input_tensor, weight=self.weight, groups=input_tensor.shape[1],
+                                    padding=self.padding)
+
+            mask = input_tensor >= self.threshold
+            input_tensor[mask] = 1.0
+            input_tensor[~mask] /= input_tensor[~mask].max()
+
+            input_tensor = input_tensor.squeeze(1)  # remove the extra channel dimension
+            output.append(input_tensor.detach().cpu().numpy())
+
+        return np.array(output)
+
+
+transform = transforms.Compose([
+    transforms.Resize((512, 512)),
+    transforms.ToTensor(),
+    transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
+])
+
+
+def init_parsing_model(model_path, device="cpu"):
+    net = BiSeNet(19)
+    net.to(device)
+    net.load_state_dict(torch.load(model_path))
+    net.eval()
+    return net
+
+
+def transform_images(imgs):
+    tensor_images = torch.stack([transform(Image.fromarray(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))) for img in imgs], dim=0)
+    return tensor_images
+
+
+def get_parsed_mask(net, imgs, classes=[1, 2, 3, 4, 5, 10, 11, 12, 13], device="cpu", batch_size=8, softness=20):
+    if softness > 0:
+        smooth_mask = SoftErosion(kernel_size=17, threshold=0.9, iterations=softness).to(device)
+
+    for i in tqdm(range(0, len(imgs), batch_size), total=len(imgs) // batch_size, desc="Face-parsing"):
+        batch_imgs = imgs[i:i + batch_size]
+
+        tensor_images = transform_images(batch_imgs).to(device)
+        with torch.no_grad():
+            out = net(tensor_images)[0]
+            # torch.isin was slightly slower in testing, so np.isin is used instead:
+            # parsing = out.argmax(dim=1)
+            # target_classes = torch.tensor(classes).to(device)
+            # batch_masks = torch.isin(parsing, target_classes).to(device)
+            parsing = out.argmax(dim=1).detach().cpu().numpy()
+            batch_masks = np.isin(parsing, classes).astype('float32')
+
+        if softness > 0:
+            mask_tensor = torch.from_numpy(batch_masks.copy()).float().to(device)
+            batch_masks = smooth_mask(mask_tensor).transpose(1, 0, 2, 3)[0]
+
+        yield batch_masks
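A hedged sketch of how this generator can be consumed; the checkpoint path and the input frame are assumptions:

    # Sketch: stream soft face masks for a list of BGR frames, one batch at a time.
    import cv2
    from face_parsing.parse_mask import init_parsing_model, get_parsed_mask

    net = init_parsing_model("assets/pretrained_models/79999_iter.pth", device="cpu")
    frames = [cv2.imread("face.jpg")]  # hypothetical single-frame input

    for batch_masks in get_parsed_mask(net, frames, batch_size=8, softness=0):
        print(batch_masks.shape)  # (batch, 512, 512) float masks in [0, 1]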
face_parsing/resnet.py ADDED
@@ -0,0 +1,109 @@
+#!/usr/bin/python
+# -*- encoding: utf-8 -*-
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.utils.model_zoo as modelzoo
+
+# from modules.bn import InPlaceABNSync as BatchNorm2d
+
+resnet18_url = 'https://download.pytorch.org/models/resnet18-5c106cde.pth'
+
+
+def conv3x3(in_planes, out_planes, stride=1):
+    """3x3 convolution with padding"""
+    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
+                     padding=1, bias=False)
+
+
+class BasicBlock(nn.Module):
+    def __init__(self, in_chan, out_chan, stride=1):
+        super(BasicBlock, self).__init__()
+        self.conv1 = conv3x3(in_chan, out_chan, stride)
+        self.bn1 = nn.BatchNorm2d(out_chan)
+        self.conv2 = conv3x3(out_chan, out_chan)
+        self.bn2 = nn.BatchNorm2d(out_chan)
+        self.relu = nn.ReLU(inplace=True)
+        self.downsample = None
+        if in_chan != out_chan or stride != 1:
+            self.downsample = nn.Sequential(
+                nn.Conv2d(in_chan, out_chan,
+                          kernel_size=1, stride=stride, bias=False),
+                nn.BatchNorm2d(out_chan),
+            )
+
+    def forward(self, x):
+        residual = self.conv1(x)
+        residual = F.relu(self.bn1(residual))
+        residual = self.conv2(residual)
+        residual = self.bn2(residual)
+
+        shortcut = x
+        if self.downsample is not None:
+            shortcut = self.downsample(x)
+
+        out = shortcut + residual
+        out = self.relu(out)
+        return out
+
+
+def create_layer_basic(in_chan, out_chan, bnum, stride=1):
+    layers = [BasicBlock(in_chan, out_chan, stride=stride)]
+    for i in range(bnum - 1):
+        layers.append(BasicBlock(out_chan, out_chan, stride=1))
+    return nn.Sequential(*layers)
+
+
+class Resnet18(nn.Module):
+    def __init__(self):
+        super(Resnet18, self).__init__()
+        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3,
+                               bias=False)
+        self.bn1 = nn.BatchNorm2d(64)
+        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
+        self.layer1 = create_layer_basic(64, 64, bnum=2, stride=1)
+        self.layer2 = create_layer_basic(64, 128, bnum=2, stride=2)
+        self.layer3 = create_layer_basic(128, 256, bnum=2, stride=2)
+        self.layer4 = create_layer_basic(256, 512, bnum=2, stride=2)
+        self.init_weight()
+
+    def forward(self, x):
+        x = self.conv1(x)
+        x = F.relu(self.bn1(x))
+        x = self.maxpool(x)
+
+        x = self.layer1(x)
+        feat8 = self.layer2(x)        # 1/8
+        feat16 = self.layer3(feat8)   # 1/16
+        feat32 = self.layer4(feat16)  # 1/32
+        return feat8, feat16, feat32
+
+    def init_weight(self):
+        state_dict = modelzoo.load_url(resnet18_url)
+        self_state_dict = self.state_dict()
+        for k, v in state_dict.items():
+            if 'fc' in k:
+                continue
+            self_state_dict.update({k: v})
+        self.load_state_dict(self_state_dict)
+
+    def get_params(self):
+        wd_params, nowd_params = [], []
+        for name, module in self.named_modules():
+            if isinstance(module, (nn.Linear, nn.Conv2d)):
+                wd_params.append(module.weight)
+                if module.bias is not None:
+                    nowd_params.append(module.bias)
+            elif isinstance(module, nn.BatchNorm2d):
+                nowd_params += list(module.parameters())
+        return wd_params, nowd_params
+
+
+if __name__ == "__main__":
+    net = Resnet18()
+    x = torch.randn(16, 3, 224, 224)
+    out = net(x)
+    print(out[0].size())
+    print(out[1].size())
+    print(out[2].size())
+    net.get_params()
face_parsing/swap.py ADDED
@@ -0,0 +1,133 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torchvision.transforms as transforms
+import cv2
+import numpy as np
+
+from .model import BiSeNet
+
+mask_regions = {
+    "Background": 0,
+    "Skin": 1,
+    "L-Eyebrow": 2,
+    "R-Eyebrow": 3,
+    "L-Eye": 4,
+    "R-Eye": 5,
+    "Eye-G": 6,
+    "L-Ear": 7,
+    "R-Ear": 8,
+    "Ear-R": 9,
+    "Nose": 10,
+    "Mouth": 11,
+    "U-Lip": 12,
+    "L-Lip": 13,
+    "Neck": 14,
+    "Neck-L": 15,
+    "Cloth": 16,
+    "Hair": 17,
+    "Hat": 18
+}
+
+
+# Borrowed from SimSwap:
+# https://github.com/neuralchen/SimSwap/blob/26c84d2901bd56eda4d5e3c5ca6da16e65dc82a6/util/reverse2original.py#L30
+class SoftErosion(nn.Module):
+    def __init__(self, kernel_size=15, threshold=0.6, iterations=1):
+        super(SoftErosion, self).__init__()
+        r = kernel_size // 2
+        self.padding = r
+        self.iterations = iterations
+        self.threshold = threshold
+
+        # Create a radially decaying kernel
+        y_indices, x_indices = torch.meshgrid(torch.arange(0., kernel_size), torch.arange(0., kernel_size))
+        dist = torch.sqrt((x_indices - r) ** 2 + (y_indices - r) ** 2)
+        kernel = dist.max() - dist
+        kernel /= kernel.sum()
+        kernel = kernel.view(1, 1, *kernel.shape)
+        self.register_buffer('weight', kernel)
+
+    def forward(self, x):
+        x = x.float()
+        for i in range(self.iterations - 1):
+            x = torch.min(x, F.conv2d(x, weight=self.weight, groups=x.shape[1], padding=self.padding))
+        x = F.conv2d(x, weight=self.weight, groups=x.shape[1], padding=self.padding)
+
+        mask = x >= self.threshold
+        x[mask] = 1.0
+        x[~mask] /= x[~mask].max()
+
+        return x, mask
+
+
+device = "cpu"
+
+
+def init_parser(pth_path, mode="cpu"):
+    global device
+    device = mode
+    n_classes = 19
+    net = BiSeNet(n_classes=n_classes)
+    if device == "cuda":
+        net.cuda()
+        net.load_state_dict(torch.load(pth_path))
+    else:
+        net.load_state_dict(torch.load(pth_path, map_location=torch.device('cpu')))
+    net.eval()
+    return net
+
+
+def image_to_parsing(img, net):
+    img = cv2.resize(img, (512, 512))
+    img = img[:, :, ::-1]
+    transform = transforms.Compose([
+        transforms.ToTensor(),
+        transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
+    ])
+    img = transform(img.copy())
+    img = torch.unsqueeze(img, 0)
+
+    with torch.no_grad():
+        img = img.to(device)
+        out = net(img)[0]
+        parsing = out.squeeze(0).cpu().numpy().argmax(0)
+    return parsing
+
+
+def get_mask(parsing, classes):
+    res = parsing == classes[0]
+    for val in classes[1:]:
+        res += parsing == val
+    return res
+
+
+def swap_regions(source, target, net, smooth_mask, includes=[1, 2, 3, 4, 5, 10, 11, 12, 13], blur=10):
+    parsing = image_to_parsing(source, net)
+
+    if len(includes) == 0:
+        # keep the return shape consistent with the normal path below (a single image)
+        return source
+
+    include_mask = get_mask(parsing, includes)
+    mask = np.repeat(include_mask[:, :, np.newaxis], 3, axis=2).astype("float32")
+
+    if smooth_mask is not None:
+        mask_tensor = torch.from_numpy(mask.copy().transpose((2, 0, 1))).float().to(device)
+        face_mask_tensor = mask_tensor[0] + mask_tensor[1]
+        soft_face_mask_tensor, _ = smooth_mask(face_mask_tensor.unsqueeze_(0).unsqueeze_(0))
+        soft_face_mask_tensor.squeeze_()
+        mask = np.repeat(soft_face_mask_tensor.cpu().numpy()[:, :, np.newaxis], 3, axis=2)
+
+    if blur > 0:
+        mask = cv2.GaussianBlur(mask, (0, 0), blur)
+
+    resized_source = cv2.resize(source.astype("float32"), (512, 512))
+    resized_target = cv2.resize(target.astype("float32"), (512, 512))
+    result = mask * resized_source + (1 - mask) * resized_target
+    result = cv2.resize(result.astype("uint8"), (source.shape[1], source.shape[0]))
+
+    return result
+
+
+def mask_regions_to_list(values):
+    out_ids = []
+    for value in values:
+        if value in mask_regions.keys():
+            out_ids.append(mask_regions.get(value))
+    return out_ids
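A minimal sketch of region-swapping with the helpers above, assuming a BiSeNet checkpoint path and two pre-cropped face images (both assumptions):

    # Sketch: keep only the listed regions from the swapped face, rest from the target.
    import cv2
    from face_parsing.swap import init_parser, swap_regions, mask_regions_to_list, SoftErosion

    net = init_parser("assets/pretrained_models/79999_iter.pth", mode="cpu")
    smooth = SoftErosion(kernel_size=17, threshold=0.9, iterations=10)
    includes = mask_regions_to_list(["Skin", "Nose", "Mouth", "U-Lip", "L-Lip"])

    source = cv2.imread("swapped_face.jpg")   # hypothetical swapped crop
    target = cv2.imread("original_face.jpg")  # hypothetical original crop
    blended = swap_regions(source, target, net, smooth, includes=includes, blur=10)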
face_swapper.py ADDED
@@ -0,0 +1,150 @@
+import time
+import torch
+import onnx
+import cv2
+import onnxruntime
+import numpy as np
+from tqdm import tqdm
+import torch.nn as nn
+from onnx import numpy_helper
+from skimage import transform as trans
+import torch.nn.functional as F
+from utils import mask_crop, laplacian_blending
+
+
+arcface_dst = np.array(
+    [[38.2946, 51.6963], [73.5318, 51.5014], [56.0252, 71.7366],
+     [41.5493, 92.3655], [70.7299, 92.2041]],
+    dtype=np.float32)
+
+
+def estimate_norm(lmk, image_size=112, mode='arcface'):
+    assert lmk.shape == (5, 2)
+    assert image_size % 112 == 0 or image_size % 128 == 0
+    if image_size % 112 == 0:
+        ratio = float(image_size) / 112.0
+        diff_x = 0
+    else:
+        ratio = float(image_size) / 128.0
+        diff_x = 8.0 * ratio
+    dst = arcface_dst * ratio
+    dst[:, 0] += diff_x
+    tform = trans.SimilarityTransform()
+    tform.estimate(lmk, dst)
+    M = tform.params[0:2, :]
+    return M
+
+
+def norm_crop2(img, landmark, image_size=112, mode='arcface'):
+    M = estimate_norm(landmark, image_size, mode)
+    warped = cv2.warpAffine(img, M, (image_size, image_size), borderValue=0.0)
+    return warped, M
+
+
+class Inswapper():
+    def __init__(self, model_file=None, batch_size=32, providers=['CPUExecutionProvider']):
+        self.model_file = model_file
+        self.batch_size = batch_size
+
+        model = onnx.load(self.model_file)
+        graph = model.graph
+        self.emap = numpy_helper.to_array(graph.initializer[-1])
+
+        self.session_options = onnxruntime.SessionOptions()
+        self.session = onnxruntime.InferenceSession(self.model_file, sess_options=self.session_options, providers=providers)
+
+    def forward(self, imgs, latents):
+        preds = []
+        for img, latent in zip(imgs, latents):
+            img = img / 255
+            pred = self.session.run(['output'], {'target': img, 'source': latent})[0]
+            preds.append(pred)
+        return preds
+
+    def get(self, imgs, target_faces, source_faces):
+        imgs = list(imgs)
+
+        preds = [None] * len(imgs)
+        matrs = [None] * len(imgs)
+
+        for idx, (img, target_face, source_face) in enumerate(zip(imgs, target_faces, source_faces)):
+            matrix, blob, latent = self.prepare_data(img, target_face, source_face)
+            pred = self.session.run(['output'], {'target': blob, 'source': latent})[0]
+            pred = pred.transpose((0, 2, 3, 1))[0]
+            pred = np.clip(255 * pred, 0, 255).astype(np.uint8)[:, :, ::-1]
+
+            preds[idx] = pred
+            matrs[idx] = matrix
+
+        return (preds, matrs)
+
+    def prepare_data(self, img, target_face, source_face):
+        if isinstance(img, str):
+            img = cv2.imread(img)
+
+        aligned_img, matrix = norm_crop2(img, target_face.kps, 128)
+
+        blob = cv2.dnn.blobFromImage(aligned_img, 1.0 / 255, (128, 128), (0., 0., 0.), swapRB=True)
+
+        latent = source_face.normed_embedding.reshape((1, -1))
+        latent = np.dot(latent, self.emap)
+        latent /= np.linalg.norm(latent)
+
+        return (matrix, blob, latent)
+
+    def batch_forward(self, img_list, target_f_list, source_f_list):
+        num_samples = len(img_list)
+        num_batches = (num_samples + self.batch_size - 1) // self.batch_size
+
+        for i in tqdm(range(num_batches), desc="Generating face"):
+            start_idx = i * self.batch_size
+            end_idx = min((i + 1) * self.batch_size, num_samples)
+
+            batch_img = img_list[start_idx:end_idx]
+            batch_target_f = target_f_list[start_idx:end_idx]
+            batch_source_f = source_f_list[start_idx:end_idx]
+
+            batch_pred, batch_matr = self.get(batch_img, batch_target_f, batch_source_f)
+
+            yield batch_pred, batch_matr
+
+
+def paste_to_whole(foreground, background, matrix, mask=None, crop_mask=(0, 0, 0, 0), blur_amount=0.1, erode_amount=0.15, blend_method='linear'):
+    inv_matrix = cv2.invertAffineTransform(matrix)
+    fg_shape = foreground.shape[:2]
+    bg_shape = (background.shape[1], background.shape[0])
+    foreground = cv2.warpAffine(foreground, inv_matrix, bg_shape, borderValue=0.0)
+
+    if mask is None:
+        mask = np.full(fg_shape, 1., dtype=np.float32)
+        mask = mask_crop(mask, crop_mask)
+        mask = cv2.warpAffine(mask, inv_matrix, bg_shape, borderValue=0.0)
+    else:
+        assert fg_shape == mask.shape[:2], "foreground & mask shape mismatch!"
+        mask = mask_crop(mask, crop_mask).astype('float32')
+        mask = cv2.warpAffine(mask, inv_matrix, bg_shape, borderValue=0.0)
+
+    _mask = mask.copy()
+    _mask[_mask > 0.05] = 1.
+    non_zero_points = cv2.findNonZero(_mask)
+    _, _, w, h = cv2.boundingRect(non_zero_points)
+    mask_size = int(np.sqrt(w * h))
+
+    if erode_amount > 0:
+        kernel_size = max(int(mask_size * erode_amount), 1)
+        structuring_element = cv2.getStructuringElement(cv2.MORPH_RECT, (kernel_size, kernel_size))
+        mask = cv2.erode(mask, structuring_element)
+
+    if blur_amount > 0:
+        kernel_size = max(int(mask_size * blur_amount), 3)
+        if kernel_size % 2 == 0:
+            kernel_size += 1
+        mask = cv2.GaussianBlur(mask, (kernel_size, kernel_size), 0)
+
+    mask = np.tile(np.expand_dims(mask, axis=-1), (1, 1, 3))
+
+    if blend_method == 'laplacian':
+        composite_image = laplacian_blending(foreground, background, mask.clip(0, 1), num_levels=4)
+    else:
+        composite_image = mask * foreground + (1 - mask) * background
+
+    # clip before casting so out-of-range values do not wrap around
+    return composite_image.clip(0, 255).astype("uint8")
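A sketch of the swap-and-paste flow. The ONNX weight path is an assumption, and insightface's FaceAnalysis is used here only as one possible source of face objects; anything that yields `.kps` and `.normed_embedding` would do:

    # Sketch: swap one face into a frame and paste it back.
    import cv2
    from insightface.app import FaceAnalysis
    from face_swapper import Inswapper, paste_to_whole

    swapper = Inswapper(model_file="assets/pretrained_models/inswapper_128.onnx")  # hypothetical path
    analyser = FaceAnalysis(name="buffalo_l", providers=["CPUExecutionProvider"])
    analyser.prepare(ctx_id=0, det_size=(640, 640))

    frame = cv2.imread("target_frame.jpg")                    # hypothetical inputs
    source_face = analyser.get(cv2.imread("source.jpg"))[0]
    target_face = analyser.get(frame)[0]

    preds, matrs = swapper.get([frame], [target_face], [source_face])
    output = paste_to_whole(preds[0], frame, matrs[0], blend_method='laplacian')
    cv2.imwrite("swapped.jpg", output)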
gfpgan/weights/detection_Resnet50_Final.pth ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:6d1de9c2944f2ccddca5f5e010ea5ae64a39845a86311af6fdf30841b0a5a16d
+size 109497761
gfpgan/weights/parsing_parsenet.pth ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:3d558d8d0e42c20224f13cf5a29c79eba2d59913419f945545d8cf7b72920de2
+size 85331193
nsfw_checker/LICENSE.md ADDED
@@ -0,0 +1,11 @@
+
+Copyright 2016, Yahoo Inc.
+
+Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met:
+
+1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer.
+
+2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution.
+
+THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
nsfw_checker/__init__.py ADDED
@@ -0,0 +1 @@
+from .opennsfw import NSFWChecker
nsfw_checker/opennsfw.py ADDED
@@ -0,0 +1,37 @@
+import cv2
+import torch
+import onnx
+import onnxruntime
+import numpy as np
+from tqdm import tqdm
+
+# https://github.com/yahoo/open_nsfw
+
+
+class NSFWChecker:
+    def __init__(self, model_path=None, providers=["CPUExecutionProvider"]):
+        model = onnx.load(model_path)
+        self.input_name = model.graph.input[0].name
+        session_options = onnxruntime.SessionOptions()
+        self.session = onnxruntime.InferenceSession(model_path, sess_options=session_options, providers=providers)
+
+    def is_nsfw(self, img_paths, threshold=0.85):
+        total_len = len(img_paths)
+        # Sample a subset of frames; the step grows with the sequence length
+        # (the >= chain also covers the exact boundary values the original skipped).
+        if total_len >= 10000:
+            skip_step = 100
+        elif total_len >= 1000:
+            skip_step = 50
+        elif total_len >= 500:
+            skip_step = 20
+        elif total_len >= 100:
+            skip_step = 10
+        else:
+            skip_step = 1
+
+        for idx in tqdm(range(0, total_len, skip_step), total=int(total_len // skip_step), desc="Checking for NSFW contents"):
+            img = cv2.imread(img_paths[idx])
+            img = cv2.resize(img, (224, 224)).astype('float32')
+            img -= np.array([104, 117, 123], dtype=np.float32)  # BGR mean subtraction
+            img = np.expand_dims(img, axis=0)
+
+            score = self.session.run(None, {self.input_name: img})[0][0][1]
+
+            if score > threshold:
+                print(f"Detected nsfw score: {score}")
+                return True
+        return False
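Usage sketch (the ONNX weight path and frame paths are assumptions):

    # Sketch: screen extracted frames before running the swap pipeline.
    from nsfw_checker import NSFWChecker

    checker = NSFWChecker(model_path="assets/pretrained_models/open-nsfw.onnx")
    if checker.is_nsfw(["frame_0001.jpg", "frame_0002.jpg"], threshold=0.85):
        raise SystemExit("NSFW content detected; aborting.")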
nsfw_detector.py ADDED
@@ -0,0 +1,65 @@
+from torchvision.transforms import Normalize
+import torchvision.transforms as T
+import torch.nn as nn
+from PIL import Image
+import numpy as np
+import torch
+import timm
+from tqdm import tqdm
+
+# https://github.com/Whiax/NSFW-Classifier/raw/main/nsfwmodel_281.pth
+normalize_t = Normalize((0.4814, 0.4578, 0.4082), (0.2686, 0.2613, 0.2757))
+
+
+# NSFW classifier: a frozen ConvNeXt backbone with a linear probe on top
+class NSFWClassifier(nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.root_model = timm.create_model('convnext_base_in22ft1k', pretrained=True)
+        self.linear_probe = nn.Linear(1024, 1, bias=False)
+
+    def forward(self, x):
+        x = normalize_t(x)
+        x = self.root_model.stem(x)
+        x = self.root_model.stages(x)
+        x = self.root_model.head.global_pool(x)
+        x = self.root_model.head.norm(x)
+        x = self.root_model.head.flatten(x)
+        x = self.linear_probe(x)
+        return x
+
+    def is_nsfw(self, img_paths, threshold=0.98):
+        total_len = len(img_paths)
+        # Sample a subset of frames; the step grows with the sequence length
+        # (the >= chain also covers the exact boundary values the original skipped).
+        if total_len >= 10000:
+            skip_step = 100
+        elif total_len >= 1000:
+            skip_step = 50
+        elif total_len >= 500:
+            skip_step = 20
+        elif total_len >= 100:
+            skip_step = 10
+        else:
+            skip_step = 1
+
+        for idx in tqdm(range(0, total_len, skip_step), total=int(total_len // skip_step), desc="Checking for NSFW contents"):
+            _img = Image.open(img_paths[idx]).convert('RGB')
+            img = _img.resize((224, 224))
+            img = np.array(img) / 255
+            img = T.ToTensor()(img).unsqueeze(0).float()
+            if next(self.parameters()).is_cuda:
+                img = img.cuda()
+            with torch.no_grad():
+                score = self.forward(img).sigmoid()[0].item()
+            if score > threshold:
+                print(f"Detected nsfw score: {score}")
+                _img.save("nsfw.jpg")
+                return True
+        return False
+
+
+def get_nsfw_detector(model_path='nsfwmodel_281.pth', device="cpu"):
+    # load the base model
+    nsfw_model = NSFWClassifier()
+    nsfw_model = nsfw_model.eval()
+    # load the linear-probe weights
+    linear_state_dict = torch.load(model_path, map_location='cpu')
+    nsfw_model.linear_probe.load_state_dict(linear_state_dict)
+    nsfw_model = nsfw_model.to(device)
+    return nsfw_model
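Usage sketch, mirroring the `is_nsfw` contract above; the weight path follows the comment at the top of the file and is otherwise an assumption:

    # Sketch: the classifier-based detector exposes the same is_nsfw interface.
    from nsfw_detector import get_nsfw_detector

    detector = get_nsfw_detector(model_path="nsfwmodel_281.pth", device="cpu")
    flagged = detector.is_nsfw(["frame_0001.jpg"], threshold=0.98)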
requirements.txt ADDED
@@ -0,0 +1,12 @@
+torch==2.0.1
+torchvision
+gradio>=3.33.1
+insightface==0.7.3
+moviepy>=1.0.3
+numpy
+onnx==1.14.0
+onnxruntime==1.15.0
+opencv-python>=4.7.0.72
+opencv-python-headless>=4.7.0.72
+gfpgan==1.3.8
+
upscaler/RealESRGAN/__init__.py ADDED
@@ -0,0 +1 @@
+from .model import RealESRGAN
upscaler/RealESRGAN/arch_utils.py ADDED
@@ -0,0 +1,197 @@
+import math
+import torch
+from torch import nn as nn
+from torch.nn import functional as F
+from torch.nn import init as init
+from torch.nn.modules.batchnorm import _BatchNorm
+
+
+@torch.no_grad()
+def default_init_weights(module_list, scale=1, bias_fill=0, **kwargs):
+    """Initialize network weights.
+
+    Args:
+        module_list (list[nn.Module] | nn.Module): Modules to be initialized.
+        scale (float): Scale initialized weights, especially for residual
+            blocks. Default: 1.
+        bias_fill (float): The value to fill bias. Default: 0
+        kwargs (dict): Other arguments for initialization function.
+    """
+    if not isinstance(module_list, list):
+        module_list = [module_list]
+    for module in module_list:
+        for m in module.modules():
+            if isinstance(m, nn.Conv2d):
+                init.kaiming_normal_(m.weight, **kwargs)
+                m.weight.data *= scale
+                if m.bias is not None:
+                    m.bias.data.fill_(bias_fill)
+            elif isinstance(m, nn.Linear):
+                init.kaiming_normal_(m.weight, **kwargs)
+                m.weight.data *= scale
+                if m.bias is not None:
+                    m.bias.data.fill_(bias_fill)
+            elif isinstance(m, _BatchNorm):
+                init.constant_(m.weight, 1)
+                if m.bias is not None:
+                    m.bias.data.fill_(bias_fill)
+
+
+def make_layer(basic_block, num_basic_block, **kwarg):
+    """Make layers by stacking the same blocks.
+
+    Args:
+        basic_block (nn.module): nn.module class for basic block.
+        num_basic_block (int): number of blocks.
+
+    Returns:
+        nn.Sequential: Stacked blocks in nn.Sequential.
+    """
+    layers = []
+    for _ in range(num_basic_block):
+        layers.append(basic_block(**kwarg))
+    return nn.Sequential(*layers)
+
+
+class ResidualBlockNoBN(nn.Module):
+    """Residual block without BN.
+
+    It has a style of:
+        ---Conv-ReLU-Conv-+-
+         |________________|
+
+    Args:
+        num_feat (int): Channel number of intermediate features.
+            Default: 64.
+        res_scale (float): Residual scale. Default: 1.
+        pytorch_init (bool): If set to True, use pytorch default init,
+            otherwise, use default_init_weights. Default: False.
+    """
+
+    def __init__(self, num_feat=64, res_scale=1, pytorch_init=False):
+        super(ResidualBlockNoBN, self).__init__()
+        self.res_scale = res_scale
+        self.conv1 = nn.Conv2d(num_feat, num_feat, 3, 1, 1, bias=True)
+        self.conv2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1, bias=True)
+        self.relu = nn.ReLU(inplace=True)
+
+        if not pytorch_init:
+            default_init_weights([self.conv1, self.conv2], 0.1)
+
+    def forward(self, x):
+        identity = x
+        out = self.conv2(self.relu(self.conv1(x)))
+        return identity + out * self.res_scale
+
+
+class Upsample(nn.Sequential):
+    """Upsample module.
+
+    Args:
+        scale (int): Scale factor. Supported scales: 2^n and 3.
+        num_feat (int): Channel number of intermediate features.
+    """
+
+    def __init__(self, scale, num_feat):
+        m = []
+        if (scale & (scale - 1)) == 0:  # scale = 2^n
+            for _ in range(int(math.log(scale, 2))):
+                m.append(nn.Conv2d(num_feat, 4 * num_feat, 3, 1, 1))
+                m.append(nn.PixelShuffle(2))
+        elif scale == 3:
+            m.append(nn.Conv2d(num_feat, 9 * num_feat, 3, 1, 1))
+            m.append(nn.PixelShuffle(3))
+        else:
+            raise ValueError(f'scale {scale} is not supported. Supported scales: 2^n and 3.')
+        super(Upsample, self).__init__(*m)
+
+
+def flow_warp(x, flow, interp_mode='bilinear', padding_mode='zeros', align_corners=True):
+    """Warp an image or feature map with optical flow.
+
+    Args:
+        x (Tensor): Tensor with size (n, c, h, w).
+        flow (Tensor): Tensor with size (n, h, w, 2), normal value.
+        interp_mode (str): 'nearest' or 'bilinear'. Default: 'bilinear'.
+        padding_mode (str): 'zeros' or 'border' or 'reflection'.
+            Default: 'zeros'.
+        align_corners (bool): Before pytorch 1.3, the default value is
+            align_corners=True. After pytorch 1.3, the default value is
+            align_corners=False. Here, we use True as the default.
+
+    Returns:
+        Tensor: Warped image or feature map.
+    """
+    assert x.size()[-2:] == flow.size()[1:3]
+    _, _, h, w = x.size()
+    # create mesh grid
+    grid_y, grid_x = torch.meshgrid(torch.arange(0, h).type_as(x), torch.arange(0, w).type_as(x))
+    grid = torch.stack((grid_x, grid_y), 2).float()  # W(x), H(y), 2
+    grid.requires_grad = False
+
+    vgrid = grid + flow
+    # scale grid to [-1, 1]
+    vgrid_x = 2.0 * vgrid[:, :, :, 0] / max(w - 1, 1) - 1.0
+    vgrid_y = 2.0 * vgrid[:, :, :, 1] / max(h - 1, 1) - 1.0
+    vgrid_scaled = torch.stack((vgrid_x, vgrid_y), dim=3)
+    output = F.grid_sample(x, vgrid_scaled, mode=interp_mode, padding_mode=padding_mode, align_corners=align_corners)
+
+    # TODO, what if align_corners=False
+    return output
+
+
+def resize_flow(flow, size_type, sizes, interp_mode='bilinear', align_corners=False):
+    """Resize a flow according to ratio or shape.
+
+    Args:
+        flow (Tensor): Precomputed flow. shape [N, 2, H, W].
+        size_type (str): 'ratio' or 'shape'.
+        sizes (list[int | float]): the ratio for resizing or the final output
+            shape.
+            1) The order of ratio should be [ratio_h, ratio_w]. For
+                downsampling, the ratio should be smaller than 1.0 (i.e.,
+                ratio < 1.0). For upsampling, the ratio should be larger
+                than 1.0 (i.e., ratio > 1.0).
+            2) The order of output_size should be [out_h, out_w].
+        interp_mode (str): The mode of interpolation for resizing.
+            Default: 'bilinear'.
+        align_corners (bool): Whether align corners. Default: False.
+
+    Returns:
+        Tensor: Resized flow.
+    """
+    _, _, flow_h, flow_w = flow.size()
+    if size_type == 'ratio':
+        output_h, output_w = int(flow_h * sizes[0]), int(flow_w * sizes[1])
+    elif size_type == 'shape':
+        output_h, output_w = sizes[0], sizes[1]
+    else:
+        raise ValueError(f'Size type should be ratio or shape, but got type {size_type}.')
+
+    input_flow = flow.clone()
+    ratio_h = output_h / flow_h
+    ratio_w = output_w / flow_w
+    input_flow[:, 0, :, :] *= ratio_w
+    input_flow[:, 1, :, :] *= ratio_h
+    resized_flow = F.interpolate(
+        input=input_flow, size=(output_h, output_w), mode=interp_mode, align_corners=align_corners)
+    return resized_flow
+
+
+# TODO: may write a cpp file
+def pixel_unshuffle(x, scale):
+    """Pixel unshuffle.
+
+    Args:
+        x (Tensor): Input feature with shape (b, c, hh, hw).
+        scale (int): Downsample ratio.
+
+    Returns:
+        Tensor: the pixel unshuffled feature.
+    """
+    b, c, hh, hw = x.size()
+    out_channel = c * (scale**2)
+    assert hh % scale == 0 and hw % scale == 0
+    h = hh // scale
+    w = hw // scale
+    x_view = x.view(b, c, h, scale, w, scale)
+    return x_view.permute(0, 1, 3, 5, 2, 4).reshape(b, out_channel, h, w)
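A quick shape check for pixel_unshuffle, the inverse of nn.PixelShuffle that RRDBNet uses for scales 1 and 2 (the import path assumes this repo's layout):

    # Sketch: (b, c, H, W) -> (b, c*scale^2, H/scale, W/scale).
    import torch
    from upscaler.RealESRGAN.arch_utils import pixel_unshuffle

    x = torch.randn(1, 3, 64, 64)
    print(pixel_unshuffle(x, scale=2).shape)  # torch.Size([1, 12, 32, 32])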
upscaler/RealESRGAN/model.py ADDED
@@ -0,0 +1,90 @@
+import os
+import torch
+from torch.nn import functional as F
+from PIL import Image
+import numpy as np
+import cv2
+
+from .rrdbnet_arch import RRDBNet
+from .utils import pad_reflect, split_image_into_overlapping_patches, stich_together, \
+    unpad_image
+
+
+HF_MODELS = {
+    2: dict(
+        repo_id='sberbank-ai/Real-ESRGAN',
+        filename='RealESRGAN_x2.pth',
+    ),
+    4: dict(
+        repo_id='sberbank-ai/Real-ESRGAN',
+        filename='RealESRGAN_x4.pth',
+    ),
+    8: dict(
+        repo_id='sberbank-ai/Real-ESRGAN',
+        filename='RealESRGAN_x8.pth',
+    ),
+}
+
+
+class RealESRGAN:
+    def __init__(self, device, scale=4):
+        self.device = device
+        self.scale = scale
+        self.model = RRDBNet(
+            num_in_ch=3, num_out_ch=3, num_feat=64,
+            num_block=23, num_grow_ch=32, scale=scale
+        )
+
+    def load_weights(self, model_path, download=True):
+        if not os.path.exists(model_path) and download:
+            from huggingface_hub import hf_hub_url, cached_download
+            assert self.scale in [2, 4, 8], 'You can download models only with scales: 2, 4, 8'
+            config = HF_MODELS[self.scale]
+            cache_dir = os.path.dirname(model_path)
+            local_filename = os.path.basename(model_path)
+            config_file_url = hf_hub_url(repo_id=config['repo_id'], filename=config['filename'])
+            cached_download(config_file_url, cache_dir=cache_dir, force_filename=local_filename)
+            print('Weights downloaded to:', os.path.join(cache_dir, local_filename))
+
+        loadnet = torch.load(model_path)
+        if 'params' in loadnet:
+            self.model.load_state_dict(loadnet['params'], strict=True)
+        elif 'params_ema' in loadnet:
+            self.model.load_state_dict(loadnet['params_ema'], strict=True)
+        else:
+            self.model.load_state_dict(loadnet, strict=True)
+        self.model.eval()
+        self.model.to(self.device)
+
+    @torch.cuda.amp.autocast()
+    def predict(self, lr_image, batch_size=4, patches_size=192,
+                padding=24, pad_size=15):
+        scale = self.scale
+        device = self.device
+        lr_image = np.array(lr_image)
+        lr_image = pad_reflect(lr_image, pad_size)
+
+        patches, p_shape = split_image_into_overlapping_patches(
+            lr_image, patch_size=patches_size, padding_size=padding
+        )
+        img = torch.FloatTensor(patches / 255).permute((0, 3, 1, 2)).to(device).detach()
+
+        with torch.no_grad():
+            res = self.model(img[0:batch_size])
+            for i in range(batch_size, img.shape[0], batch_size):
+                res = torch.cat((res, self.model(img[i:i + batch_size])), 0)
+
+        sr_image = res.permute((0, 2, 3, 1)).clamp_(0, 1).cpu()
+        np_sr_image = sr_image.numpy()
+
+        padded_size_scaled = tuple(np.multiply(p_shape[0:2], scale)) + (3,)
+        scaled_image_shape = tuple(np.multiply(lr_image.shape[0:2], scale)) + (3,)
+        np_sr_image = stich_together(
+            np_sr_image, padded_image_shape=padded_size_scaled,
+            target_shape=scaled_image_shape, padding_size=padding * scale
+        )
+        sr_img = (np_sr_image * 255).astype(np.uint8)
+        sr_img = unpad_image(sr_img, pad_size * scale)
+        # sr_img = Image.fromarray(sr_img)
+
+        return sr_img
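A minimal usage sketch; the local weight path is an assumption, and with download=True the weights are fetched from the sberbank-ai/Real-ESRGAN repo when missing:

    # Sketch: upscale one RGB image 4x.
    import cv2
    from upscaler.RealESRGAN import RealESRGAN

    model = RealESRGAN("cpu", scale=4)
    model.load_weights("assets/pretrained_models/RealESRGAN_x4.pth", download=True)

    lr = cv2.cvtColor(cv2.imread("low_res.jpg"), cv2.COLOR_BGR2RGB)  # hypothetical input
    sr = model.predict(lr)  # numpy uint8 array, 4x the input resolution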
upscaler/RealESRGAN/rrdbnet_arch.py ADDED
@@ -0,0 +1,121 @@
+import torch
+from torch import nn as nn
+from torch.nn import functional as F
+
+from .arch_utils import default_init_weights, make_layer, pixel_unshuffle
+
+
+class ResidualDenseBlock(nn.Module):
+    """Residual Dense Block.
+
+    Used in RRDB block in ESRGAN.
+
+    Args:
+        num_feat (int): Channel number of intermediate features.
+        num_grow_ch (int): Channels for each growth.
+    """
+
+    def __init__(self, num_feat=64, num_grow_ch=32):
+        super(ResidualDenseBlock, self).__init__()
+        self.conv1 = nn.Conv2d(num_feat, num_grow_ch, 3, 1, 1)
+        self.conv2 = nn.Conv2d(num_feat + num_grow_ch, num_grow_ch, 3, 1, 1)
+        self.conv3 = nn.Conv2d(num_feat + 2 * num_grow_ch, num_grow_ch, 3, 1, 1)
+        self.conv4 = nn.Conv2d(num_feat + 3 * num_grow_ch, num_grow_ch, 3, 1, 1)
+        self.conv5 = nn.Conv2d(num_feat + 4 * num_grow_ch, num_feat, 3, 1, 1)
+
+        self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
+
+        # initialization
+        default_init_weights([self.conv1, self.conv2, self.conv3, self.conv4, self.conv5], 0.1)
+
+    def forward(self, x):
+        x1 = self.lrelu(self.conv1(x))
+        x2 = self.lrelu(self.conv2(torch.cat((x, x1), 1)))
+        x3 = self.lrelu(self.conv3(torch.cat((x, x1, x2), 1)))
+        x4 = self.lrelu(self.conv4(torch.cat((x, x1, x2, x3), 1)))
+        x5 = self.conv5(torch.cat((x, x1, x2, x3, x4), 1))
+        # Empirically, scaling the residual by 0.2 gives better performance
+        return x5 * 0.2 + x
+
+
+class RRDB(nn.Module):
+    """Residual in Residual Dense Block.
+
+    Used in RRDB-Net in ESRGAN.
+
+    Args:
+        num_feat (int): Channel number of intermediate features.
+        num_grow_ch (int): Channels for each growth.
+    """
+
+    def __init__(self, num_feat, num_grow_ch=32):
+        super(RRDB, self).__init__()
+        self.rdb1 = ResidualDenseBlock(num_feat, num_grow_ch)
+        self.rdb2 = ResidualDenseBlock(num_feat, num_grow_ch)
+        self.rdb3 = ResidualDenseBlock(num_feat, num_grow_ch)
+
+    def forward(self, x):
+        out = self.rdb1(x)
+        out = self.rdb2(out)
+        out = self.rdb3(out)
+        # Empirically, scaling the residual by 0.2 gives better performance
+        return out * 0.2 + x
+
+
+class RRDBNet(nn.Module):
+    """Networks consisting of Residual in Residual Dense Blocks, as used
+    in ESRGAN.
+
+    ESRGAN: Enhanced Super-Resolution Generative Adversarial Networks.
+
+    We extend ESRGAN for scale x2 and scale x1.
+    Note: this is one option for scale 1 and scale 2 in RRDBNet.
+    We first employ pixel-unshuffle (an inverse operation of pixelshuffle) to reduce
+    the spatial size and enlarge the channel size before feeding inputs into the
+    main ESRGAN architecture.
+
+    Args:
+        num_in_ch (int): Channel number of inputs.
+        num_out_ch (int): Channel number of outputs.
+        num_feat (int): Channel number of intermediate features.
+            Default: 64
+        num_block (int): Block number in the trunk network. Default: 23
+        num_grow_ch (int): Channels for each growth. Default: 32.
+    """
+
+    def __init__(self, num_in_ch, num_out_ch, scale=4, num_feat=64, num_block=23, num_grow_ch=32):
+        super(RRDBNet, self).__init__()
+        self.scale = scale
+        if scale == 2:
+            num_in_ch = num_in_ch * 4
+        elif scale == 1:
+            num_in_ch = num_in_ch * 16
+        self.conv_first = nn.Conv2d(num_in_ch, num_feat, 3, 1, 1)
+        self.body = make_layer(RRDB, num_block, num_feat=num_feat, num_grow_ch=num_grow_ch)
+        self.conv_body = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
+        # upsample
+        self.conv_up1 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
+        self.conv_up2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
+        if scale == 8:
+            self.conv_up3 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
+        self.conv_hr = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
+        self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)
+
+        self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
+
+    def forward(self, x):
+        if self.scale == 2:
+            feat = pixel_unshuffle(x, scale=2)
+        elif self.scale == 1:
+            feat = pixel_unshuffle(x, scale=4)
+        else:
+            feat = x
+        feat = self.conv_first(feat)
+        body_feat = self.conv_body(self.body(feat))
+        feat = feat + body_feat
+        # upsample
+        feat = self.lrelu(self.conv_up1(F.interpolate(feat, scale_factor=2, mode='nearest')))
+        feat = self.lrelu(self.conv_up2(F.interpolate(feat, scale_factor=2, mode='nearest')))
+        if self.scale == 8:
+            feat = self.lrelu(self.conv_up3(F.interpolate(feat, scale_factor=2, mode='nearest')))
+        out = self.conv_last(self.lrelu(self.conv_hr(feat)))
+        return out
upscaler/RealESRGAN/utils.py ADDED
@@ -0,0 +1,133 @@
+import numpy as np
+import torch
+from PIL import Image
+import os
+import io
+
+
+def pad_reflect(image, pad_size):
+    imsize = image.shape
+    height, width = imsize[:2]
+    new_img = np.zeros([height + pad_size * 2, width + pad_size * 2, imsize[2]]).astype(np.uint8)
+    new_img[pad_size:-pad_size, pad_size:-pad_size, :] = image
+
+    new_img[0:pad_size, pad_size:-pad_size, :] = np.flip(image[0:pad_size, :, :], axis=0)  # top
+    new_img[-pad_size:, pad_size:-pad_size, :] = np.flip(image[-pad_size:, :, :], axis=0)  # bottom
+    new_img[:, 0:pad_size, :] = np.flip(new_img[:, pad_size:pad_size * 2, :], axis=1)      # left
+    new_img[:, -pad_size:, :] = np.flip(new_img[:, -pad_size * 2:-pad_size, :], axis=1)    # right
+
+    return new_img
+
+
+def unpad_image(image, pad_size):
+    return image[pad_size:-pad_size, pad_size:-pad_size, :]
+
+
+def process_array(image_array, expand=True):
+    """Process a 3-dimensional array into a scaled, 4-dimensional batch of size 1."""
+    image_batch = image_array / 255.0
+    if expand:
+        image_batch = np.expand_dims(image_batch, axis=0)
+    return image_batch
+
+
+def process_output(output_tensor):
+    """Transforms the 4-dimensional output tensor into a suitable image format."""
+    sr_img = output_tensor.clip(0, 1) * 255
+    sr_img = np.uint8(sr_img)
+    return sr_img
+
+
+def pad_patch(image_patch, padding_size, channel_last=True):
+    """Pads image_patch with padding_size edge values."""
+    if channel_last:
+        return np.pad(
+            image_patch,
+            ((padding_size, padding_size), (padding_size, padding_size), (0, 0)),
+            'edge',
+        )
+    else:
+        return np.pad(
+            image_patch,
+            ((0, 0), (padding_size, padding_size), (padding_size, padding_size)),
+            'edge',
+        )
+
+
+def unpad_patches(image_patches, padding_size):
+    return image_patches[:, padding_size:-padding_size, padding_size:-padding_size, :]
+
+
+def split_image_into_overlapping_patches(image_array, patch_size, padding_size=2):
+    """Splits the image into partially overlapping patches.
+
+    The patches overlap by padding_size pixels.
+    Pads the image twice:
+        - first to have a size multiple of the patch size,
+        - then to have equal padding at the borders.
+
+    Args:
+        image_array: numpy array of the input image.
+        patch_size: size of the patches from the original image (without padding).
+        padding_size: size of the overlapping area.
+    """
+    xmax, ymax, _ = image_array.shape
+    x_remainder = xmax % patch_size
+    y_remainder = ymax % patch_size
+
+    # modulo here is to avoid extending by patch_size instead of 0
+    x_extend = (patch_size - x_remainder) % patch_size
+    y_extend = (patch_size - y_remainder) % patch_size
+
+    # make sure the image is divisible into regular patches
+    extended_image = np.pad(image_array, ((0, x_extend), (0, y_extend), (0, 0)), 'edge')
+
+    # add padding around the image to simplify computations
+    padded_image = pad_patch(extended_image, padding_size, channel_last=True)
+
+    xmax, ymax, _ = padded_image.shape
+    patches = []
+
+    x_lefts = range(padding_size, xmax - padding_size, patch_size)
+    y_tops = range(padding_size, ymax - padding_size, patch_size)
+
+    for x in x_lefts:
+        for y in y_tops:
+            x_left = x - padding_size
+            y_top = y - padding_size
+            x_right = x + patch_size + padding_size
+            y_bottom = y + patch_size + padding_size
+            patch = padded_image[x_left:x_right, y_top:y_bottom, :]
+            patches.append(patch)
+
+    return np.array(patches), padded_image.shape
+
+
+def stich_together(patches, padded_image_shape, target_shape, padding_size=4):
+    """Reconstruct the image from overlapping patches.
+
+    After scaling, shapes and padding should be scaled too.
+
+    Args:
+        patches: patches obtained with split_image_into_overlapping_patches
+        padded_image_shape: shape of the padded image constructed in split_image_into_overlapping_patches
+        target_shape: shape of the final image
+        padding_size: size of the overlapping area.
+    """
+    xmax, ymax, _ = padded_image_shape
+    patches = unpad_patches(patches, padding_size)
+    patch_size = patches.shape[1]
+    n_patches_per_row = ymax // patch_size
+
+    complete_image = np.zeros((xmax, ymax, 3))
+
+    row = -1
+    col = 0
+    for i in range(len(patches)):
+        if i % n_patches_per_row == 0:
+            row += 1
+            col = 0
+        complete_image[
+            row * patch_size: (row + 1) * patch_size, col * patch_size: (col + 1) * patch_size, :
+        ] = patches[i]
+        col += 1
+    return complete_image[0: target_shape[0], 0: target_shape[1], :]
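A small self-check of the patching helpers (a sketch, not part of the committed code): with no upscaling in between, splitting and stitching should round-trip the image exactly:

    # Sketch: split/stitch round trip at scale 1 reproduces the input.
    import numpy as np
    from upscaler.RealESRGAN.utils import split_image_into_overlapping_patches, stich_together

    img = np.random.randint(0, 255, (200, 300, 3), dtype=np.uint8)
    patches, padded_shape = split_image_into_overlapping_patches(img, patch_size=64, padding_size=8)
    restored = stich_together(patches, padded_image_shape=padded_shape,
                              target_shape=img.shape, padding_size=8)
    assert np.allclose(restored, img)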
upscaler/__init__.py ADDED
File without changes
upscaler/codeformer.py ADDED
@@ -0,0 +1,37 @@
+import cv2
+import torch
+import onnx
+import onnxruntime
+import numpy as np
+
+import time
+
+# CodeFormer converted to ONNX
+# using https://github.com/redthing1/CodeFormer
+
+
+class CodeFormerEnhancer:
+    def __init__(self, model_path="codeformer.onnx", device='cpu'):
+        model = onnx.load(model_path)
+        session_options = onnxruntime.SessionOptions()
+        session_options.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
+        providers = ["CPUExecutionProvider"]
+        if device == 'cuda':
+            providers = [("CUDAExecutionProvider", {"cudnn_conv_algo_search": "DEFAULT"}), "CPUExecutionProvider"]
+        self.session = onnxruntime.InferenceSession(model_path, sess_options=session_options, providers=providers)
+
+    def enhance(self, img, w=0.9):
+        img = cv2.resize(img, (512, 512), interpolation=cv2.INTER_LINEAR)
+        img = img.astype(np.float32)[:, :, ::-1] / 255.0  # BGR -> RGB, [0, 1]
+        img = img.transpose((2, 0, 1))
+        nrm_mean = np.array([0.5, 0.5, 0.5]).reshape((-1, 1, 1))
+        nrm_std = np.array([0.5, 0.5, 0.5]).reshape((-1, 1, 1))
+        img = (img - nrm_mean) / nrm_std
+
+        img = np.expand_dims(img, axis=0)
+
+        out = self.session.run(None, {'x': img.astype(np.float32), 'w': np.array([w], dtype=np.double)})[0]
+        out = (out[0].transpose(1, 2, 0).clip(-1, 1) + 1) * 0.5
+        out = (out * 255)[:, :, ::-1]  # RGB -> BGR
+
+        return out.astype('uint8')
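Usage sketch; the ONNX model path is an assumption, and `w` is CodeFormer's fidelity weight (higher values keep the result closer to the input):

    # Sketch: enhance an aligned face crop with the ONNX CodeFormer model.
    import cv2
    from upscaler.codeformer import CodeFormerEnhancer

    enhancer = CodeFormerEnhancer(model_path="assets/pretrained_models/codeformer.onnx")
    face = cv2.imread("aligned_face.jpg")     # hypothetical aligned crop
    restored = enhancer.enhance(face, w=0.9)  # BGR uint8, 512x512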
utils.py ADDED
@@ -0,0 +1,303 @@
+import os
+import cv2
+import time
+import glob
+import shutil
+import platform
+import datetime
+import subprocess
+import numpy as np
+from threading import Thread
+from moviepy.editor import VideoFileClip, ImageSequenceClip
+from moviepy.video.io.ffmpeg_tools import ffmpeg_extract_subclip
+
+
+logo_image = cv2.imread("./assets/images/logo.png", cv2.IMREAD_UNCHANGED)
+
+
+quality_types = ["poor", "low", "medium", "high", "best"]
+
+
+bitrate_quality_by_resolution = {
+    240: {"poor": "300k", "low": "500k", "medium": "800k", "high": "1000k", "best": "1200k"},
+    360: {"poor": "500k", "low": "800k", "medium": "1200k", "high": "1500k", "best": "2000k"},
+    480: {"poor": "800k", "low": "1200k", "medium": "2000k", "high": "2500k", "best": "3000k"},
+    720: {"poor": "1500k", "low": "2500k", "medium": "4000k", "high": "5000k", "best": "6000k"},
+    1080: {"poor": "2500k", "low": "4000k", "medium": "6000k", "high": "7000k", "best": "8000k"},
+    1440: {"poor": "4000k", "low": "6000k", "medium": "8000k", "high": "10000k", "best": "12000k"},
+    2160: {"poor": "8000k", "low": "10000k", "medium": "12000k", "high": "15000k", "best": "20000k"}
+}
+
+
+crf_quality_by_resolution = {
+    240: {"poor": 45, "low": 35, "medium": 28, "high": 23, "best": 20},
+    360: {"poor": 35, "low": 28, "medium": 23, "high": 20, "best": 18},
+    480: {"poor": 28, "low": 23, "medium": 20, "high": 18, "best": 16},
+    720: {"poor": 23, "low": 20, "medium": 18, "high": 16, "best": 14},
+    1080: {"poor": 20, "low": 18, "medium": 16, "high": 14, "best": 12},
+    1440: {"poor": 18, "low": 16, "medium": 14, "high": 12, "best": 10},
+    2160: {"poor": 16, "low": 14, "medium": 12, "high": 10, "best": 8}
+}
+
+
+def get_bitrate_for_resolution(resolution, quality):
+    available_resolutions = list(bitrate_quality_by_resolution.keys())
+    closest_resolution = min(available_resolutions, key=lambda x: abs(x - resolution))
+    return bitrate_quality_by_resolution[closest_resolution][quality]
+
+
+def get_crf_for_resolution(resolution, quality):
+    available_resolutions = list(crf_quality_by_resolution.keys())
+    closest_resolution = min(available_resolutions, key=lambda x: abs(x - resolution))
+    return crf_quality_by_resolution[closest_resolution][quality]
+
+
+def get_video_bitrate(video_file):
+    ffprobe_cmd = ['ffprobe', '-v', 'error', '-select_streams', 'v:0', '-show_entries',
+                   'stream=bit_rate', '-of', 'default=noprint_wrappers=1:nokey=1', video_file]
+    result = subprocess.run(ffprobe_cmd, stdout=subprocess.PIPE)
+    kbps = max(int(result.stdout) // 1000, 10)
+    return str(kbps) + 'k'
+
+
+def trim_video(video_path, output_path, start_frame, stop_frame):
+    video_name, _ = os.path.splitext(os.path.basename(video_path))
+    trimmed_video_filename = video_name + "_trimmed" + ".mp4"
+    temp_path = os.path.join(output_path, "trim")
+    os.makedirs(temp_path, exist_ok=True)
+    trimmed_video_file_path = os.path.join(temp_path, trimmed_video_filename)
+
+    video = VideoFileClip(video_path, fps_source="fps")
+    fps = video.fps
+    start_time = start_frame / fps
+    duration = (stop_frame - start_frame) / fps
+
+    bitrate = get_bitrate_for_resolution(min(*video.size), "high")
+
+    trimmed_video = video.subclip(start_time, start_time + duration)
+    trimmed_video.write_videofile(
+        trimmed_video_file_path, codec="libx264", audio_codec="aac", bitrate=bitrate,
+    )
+    trimmed_video.close()
+    video.close()
+
+    return trimmed_video_file_path
+
+
+def open_directory(path=None):
+    if path is None:
+        return
+    try:
+        os.startfile(path)  # Windows
+    except (AttributeError, OSError):
+        subprocess.Popen(["xdg-open", path])  # Linux fallback
+
+
+class StreamerThread(object):
+    def __init__(self, src=0):
+        self.capture = cv2.VideoCapture(src)
+        self.capture.set(cv2.CAP_PROP_BUFFERSIZE, 2)
+        self.FPS = 1 / 30
+        self.FPS_MS = int(self.FPS * 1000)
+        self.thread = None
+        self.stopped = False
+        self.frame = None
+
+    def start(self):
+        self.thread = Thread(target=self.update, args=())
+        self.thread.daemon = True
+        self.thread.start()
+
+    def stop(self):
+        self.stopped = True
+        self.thread.join()
+        print("stopped")
+
+    def update(self):
+        while not self.stopped:
+            if self.capture.isOpened():
+                (self.status, self.frame) = self.capture.read()
+            time.sleep(self.FPS)
+
+
+class ProcessBar:
+    def __init__(self, bar_length, total, before="⬛", after="🟨"):
+        self.bar_length = bar_length
+        self.total = total
+        self.before = before
+        self.after = after
+        self.bar = [self.before] * bar_length
+        self.start_time = time.time()
+
+    def get(self, index):
+        total = self.total
+        elapsed_time = time.time() - self.start_time
+        average_time_per_iteration = elapsed_time / (index + 1)
+        remaining_iterations = total - (index + 1)
+        estimated_remaining_time = remaining_iterations * average_time_per_iteration
+
+        self.bar[int(index / total * self.bar_length)] = self.after
+        info_text = f"({index + 1}/{total}) {''.join(self.bar)} "
+        info_text += f"(ETR: {int(estimated_remaining_time // 60)} min {int(estimated_remaining_time % 60)} sec)"
+        return info_text
+
+
+def add_logo_to_image(img, logo=logo_image):
+    logo_size = int(img.shape[1] * 0.1)
+    logo = cv2.resize(logo, (logo_size, logo_size))
+    if logo.shape[2] == 4:
+        alpha = logo[:, :, 3]
+    else:
+        alpha = np.ones_like(logo[:, :, 0]) * 255
+    padding = int(logo_size * 0.1)
+    roi = img.shape[0] - logo_size - padding, img.shape[1] - logo_size - padding
+    for c in range(0, 3):
+        img[roi[0]: roi[0] + logo_size, roi[1]: roi[1] + logo_size, c] = (
+            alpha / 255.0
+        ) * logo[:, :, c] + (1 - alpha / 255.0) * img[
+            roi[0]: roi[0] + logo_size, roi[1]: roi[1] + logo_size, c
+        ]
+    return img
+
+
+def split_list_by_lengths(data, length_list):
+    split_data = []
+    start_idx = 0
+    for length in length_list:
+        end_idx = start_idx + length
+        sublist = data[start_idx:end_idx]
+        split_data.append(sublist)
+        start_idx = end_idx
+    return split_data
+
+
+def merge_img_sequence_from_ref(ref_video_path, image_sequence, output_file_name):
+    video_clip = VideoFileClip(ref_video_path, fps_source="fps")
+    fps = video_clip.fps
+    duration = video_clip.duration
+    total_frames = video_clip.reader.nframes
+    audio_clip = video_clip.audio if video_clip.audio is not None else None
+    edited_video_clip = ImageSequenceClip(image_sequence, fps=fps)
+
+    if audio_clip is not None:
+        edited_video_clip = edited_video_clip.set_audio(audio_clip)
+
+    bitrate = get_bitrate_for_resolution(min(*edited_video_clip.size), "high")
+
+    edited_video_clip.set_duration(duration).write_videofile(
+        output_file_name, codec="libx264", bitrate=bitrate,
+    )
+    edited_video_clip.close()
+    video_clip.close()
+
+
+def scale_bbox_from_center(bbox, scale_width, scale_height, image_width, image_height):
+    # Extract the coordinates of the bbox
+    x1, y1, x2, y2 = bbox
+
+    # Calculate the center point of the bbox
+    center_x = (x1 + x2) / 2
+    center_y = (y1 + y2) / 2
+
+    # Calculate the new width and height of the bbox based on the scaling factors
+    width = x2 - x1
+    height = y2 - y1
+    new_width = width * scale_width
+    new_height = height * scale_height
+
+    # Calculate the new coordinates of the bbox, considering the image boundaries
+    new_x1 = center_x - new_width / 2
+    new_y1 = center_y - new_height / 2
+    new_x2 = center_x + new_width / 2
+    new_y2 = center_y + new_height / 2
+
+    # Adjust the coordinates to ensure the bbox remains within the image boundaries
+    new_x1 = max(0, new_x1)
+    new_y1 = max(0, new_y1)
+    new_x2 = min(image_width - 1, new_x2)
+    new_y2 = min(image_height - 1, new_y2)
+
+    # Return the scaled bbox coordinates
+    scaled_bbox = [new_x1, new_y1, new_x2, new_y2]
+    return scaled_bbox
+
+
+def laplacian_blending(A, B, m, num_levels=7):
+    assert A.shape == B.shape
+    assert B.shape == m.shape
+    height = m.shape[0]
+    width = m.shape[1]
+    size_list = np.array([4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096, 8192])
+    size = size_list[np.where(size_list > max(height, width))][0]
+    # Pad all three inputs onto a square power-of-two canvas for the pyramids
+    GA = np.zeros((size, size, 3), dtype=np.float32)
+    GA[:height, :width, :] = A
+    GB = np.zeros((size, size, 3), dtype=np.float32)
+    GB[:height, :width, :] = B
+    GM = np.zeros((size, size, 3), dtype=np.float32)
+    GM[:height, :width, :] = m
+    # Gaussian pyramids for both images and the mask
+    gpA = [GA]
+    gpB = [GB]
+    gpM = [GM]
+    for i in range(num_levels):
+        GA = cv2.pyrDown(GA)
+        GB = cv2.pyrDown(GB)
+        GM = cv2.pyrDown(GM)
+        gpA.append(np.float32(GA))
+        gpB.append(np.float32(GB))
+        gpM.append(np.float32(GM))
+    # Laplacian pyramids, plus the matching mask levels
+    lpA = [gpA[num_levels - 1]]
+    lpB = [gpB[num_levels - 1]]
+    gpMr = [gpM[num_levels - 1]]
+    for i in range(num_levels - 1, 0, -1):
+        LA = np.subtract(gpA[i - 1], cv2.pyrUp(gpA[i]))
+        LB = np.subtract(gpB[i - 1], cv2.pyrUp(gpB[i]))
+        lpA.append(LA)
+        lpB.append(LB)
+        gpMr.append(gpM[i - 1])
+    # Blend each level with the mask, then collapse the pyramid
+    LS = []
+    for la, lb, gm in zip(lpA, lpB, gpMr):
+        ls = la * gm + lb * (1.0 - gm)
+        LS.append(ls)
+    ls_ = LS[0]
+    for i in range(1, num_levels):
+        ls_ = cv2.pyrUp(ls_)
+        ls_ = cv2.add(ls_, LS[i])
+    ls_ = ls_[:height, :width, :]
+    # ls_ = (ls_ - np.min(ls_)) * (255.0 / (np.max(ls_) - np.min(ls_)))
+    return ls_.clip(0, 255)
+
+
+def mask_crop(mask, crop):
+    top, bottom, left, right = crop
+    shape = mask.shape
+    top = int(top)
+    bottom = int(bottom)
+    # top/bottom index rows (shape[0]); left/right index columns (shape[1])
+    if top + bottom < shape[0]:
+        if top > 0:
+            mask[:top, :] = 0
+        if bottom > 0:
+            mask[-bottom:, :] = 0
+
+    left = int(left)
+    right = int(right)
+    if left + right < shape[1]:
+        if left > 0:
+            mask[:, :left] = 0
+        if right > 0:
+            mask[:, -right:] = 0
+
+    return mask
+
+
+def create_image_grid(images, size=128):
+    num_images = len(images)
+    num_cols = int(np.ceil(np.sqrt(num_images)))
+    num_rows = int(np.ceil(num_images / num_cols))
+    grid = np.zeros((num_rows * size, num_cols * size, 3), dtype=np.uint8)
+
+    for i, image in enumerate(images):
+        row_idx = (i // num_cols) * size
+        col_idx = (i % num_cols) * size
+        image = cv2.resize(image.copy(), (size, size))
+        if image.dtype != np.uint8:
+            image = (image.astype('float32') * 255).astype('uint8')
+        if image.ndim == 2:
+            image = cv2.cvtColor(image, cv2.COLOR_GRAY2BGR)
+        grid[row_idx:row_idx + size, col_idx:col_idx + size] = image
+
+    return grid
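A small example of the nearest-resolution lookups above (a sketch; importing utils assumes the project's dependencies are installed, since it pulls in moviepy at module load):

    # Sketch: both quality lookups snap to the nearest listed resolution,
    # so a 1088-pixel dimension is treated as 1080p.
    from utils import get_crf_for_resolution, get_bitrate_for_resolution

    print(get_crf_for_resolution(1088, "high"))      # 14
    print(get_bitrate_for_resolution(1088, "high"))  # "7000k"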