Upload 12 files
- roop/processors/Enhance_CodeFormer.py +71 -0
- roop/processors/Enhance_DMDNet.py +898 -0
- roop/processors/Enhance_GFPGAN.py +73 -0
- roop/processors/Enhance_GPEN.py +63 -0
- roop/processors/Enhance_RestoreFormerPPlus.py +64 -0
- roop/processors/FaceSwapInsightFace.py +61 -0
- roop/processors/Frame_Colorizer.py +70 -0
- roop/processors/Frame_Filter.py +105 -0
- roop/processors/Frame_Masking.py +71 -0
- roop/processors/Frame_Upscale.py +129 -0
- roop/processors/Mask_Clip2Seg.py +94 -0
- roop/processors/Mask_XSeg.py +58 -0
roop/processors/Enhance_CodeFormer.py
ADDED
@@ -0,0 +1,71 @@
from typing import Any, List, Callable
import cv2
import numpy as np
import onnxruntime
import roop.globals

from roop.typing import Face, Frame, FaceSet
from roop.utilities import resolve_relative_path


class Enhance_CodeFormer():
    model_codeformer = None

    plugin_options:dict = None

    processorname = 'codeformer'
    type = 'enhance'


    def Initialize(self, plugin_options:dict):
        if self.plugin_options is not None:
            if self.plugin_options["devicename"] != plugin_options["devicename"]:
                self.Release()

        self.plugin_options = plugin_options
        if self.model_codeformer is None:
            # replace Mac mps with cpu for the moment
            self.devicename = self.plugin_options["devicename"].replace('mps', 'cpu')
            model_path = resolve_relative_path('../models/CodeFormer/CodeFormerv0.1.onnx')
            self.model_codeformer = onnxruntime.InferenceSession(model_path, None, providers=roop.globals.execution_providers)
            self.model_inputs = self.model_codeformer.get_inputs()
            model_outputs = self.model_codeformer.get_outputs()
            self.io_binding = self.model_codeformer.io_binding()
            self.io_binding.bind_cpu_input(self.model_inputs[1].name, np.array([0.5]))
            self.io_binding.bind_output(model_outputs[0].name, self.devicename)


    def Run(self, source_faceset: FaceSet, target_face: Face, temp_frame: Frame) -> Frame:
        input_size = temp_frame.shape[1]
        # preprocess
        temp_frame = cv2.resize(temp_frame, (512, 512), interpolation=cv2.INTER_CUBIC)
        temp_frame = cv2.cvtColor(temp_frame, cv2.COLOR_BGR2RGB)
        temp_frame = temp_frame.astype('float32') / 255.0
        temp_frame = (temp_frame - 0.5) / 0.5
        temp_frame = np.expand_dims(temp_frame, axis=0).transpose(0, 3, 1, 2)

        self.io_binding.bind_cpu_input(self.model_inputs[0].name, temp_frame.astype(np.float32))
        self.model_codeformer.run_with_iobinding(self.io_binding)
        ort_outs = self.io_binding.copy_outputs_to_cpu()
        result = ort_outs[0][0]
        del ort_outs

        # post-process
        result = result.transpose((1, 2, 0))

        un_min = -1.0
        un_max = 1.0
        result = np.clip(result, un_min, un_max)
        result = (result - un_min) / (un_max - un_min)

        result = cv2.cvtColor(result, cv2.COLOR_RGB2BGR)
        result = (result * 255.0).round()
        scale_factor = int(result.shape[1] / input_size)
        return result.astype(np.uint8), scale_factor


    def Release(self):
        del self.model_codeformer
        self.model_codeformer = None
        del self.io_binding
        self.io_binding = None
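All of the enhancer plugins added in this commit expose the same small interface: a processorname/type pair plus Initialize, Run and Release. As a rough illustration of how one of them could be driven on its own outside the full roop pipeline, here is a minimal sketch; the image path, calling Run with None for the face arguments, and writing the result to disk are assumptions for illustration, not part of the diff.

# Hypothetical driver for one of the enhancer plugins above (not part of this commit).
import cv2
from roop.processors.Enhance_CodeFormer import Enhance_CodeFormer

processor = Enhance_CodeFormer()
processor.Initialize({"devicename": "cpu"})             # "devicename" is the only option key the plugin reads
face_crop = cv2.imread("aligned_face.png")              # assumed: an already cropped and aligned face image
enhanced, scale = processor.Run(None, None, face_crop)  # this plugin ignores the faceset/face arguments
cv2.imwrite("enhanced_face.png", enhanced)              # 512x512 result plus the scale factor vs. the input size
processor.Release()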
roop/processors/Enhance_DMDNet.py
ADDED
@@ -0,0 +1,898 @@
from typing import Any, List, Callable
import cv2
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.utils.spectral_norm as SpectralNorm
import threading
from torchvision.ops import roi_align

from math import sqrt

from torchvision.transforms.functional import normalize

from roop.typing import Face, Frame, FaceSet


THREAD_LOCK_DMDNET = threading.Lock()


class Enhance_DMDNet():
    plugin_options:dict = None
    model_dmdnet = None
    torchdevice = None

    processorname = 'dmdnet'
    type = 'enhance'


    def Initialize(self, plugin_options:dict):
        if self.plugin_options is not None:
            if self.plugin_options["devicename"] != plugin_options["devicename"]:
                self.Release()

        self.plugin_options = plugin_options
        if self.model_dmdnet is None:
            self.model_dmdnet = self.create(self.plugin_options["devicename"])


    # temp_frame already cropped+aligned, bbox not
    def Run(self, source_faceset: FaceSet, target_face: Face, temp_frame: Frame) -> Frame:
        input_size = temp_frame.shape[1]

        result = self.enhance_face(source_faceset, temp_frame, target_face)
        scale_factor = int(result.shape[1] / input_size)
        return result.astype(np.uint8), scale_factor


    def Release(self):
        self.model_dmdnet = None


    # https://stackoverflow.com/a/67174339
    def landmarks106_to_68(self, pt106):
        map106to68=[1,10,12,14,16,3,5,7,0,23,21,19,32,30,28,26,17,
                    43,48,49,51,50,
                    102,103,104,105,101,
                    72,73,74,86,78,79,80,85,84,
                    35,41,42,39,37,36,
                    89,95,96,93,91,90,
                    52,64,63,71,67,68,61,58,59,53,56,55,65,66,62,70,69,57,60,54
                    ]

        pt68 = []
        for i in range(68):
            index = map106to68[i]
            pt68.append(pt106[index])
        return pt68


    def check_bbox(self, imgs, boxes):
        boxes = boxes.view(-1, 4, 4)
        colors = [(0, 255, 0), (0, 255, 0), (255, 255, 0), (255, 0, 0)]
        i = 0
        for img, box in zip(imgs, boxes):
            img = (img + 1)/2 * 255
            img2 = img.permute(1, 2, 0).float().cpu().flip(2).numpy().copy()
            for idx, point in enumerate(box):
                cv2.rectangle(img2, (int(point[0]), int(point[1])), (int(point[2]), int(point[3])), color=colors[idx], thickness=2)
            cv2.imwrite('dmdnet_{:02d}.png'.format(i), img2)
            i += 1


    def trans_points2d(self, pts, M):
        new_pts = np.zeros(shape=pts.shape, dtype=np.float32)
        for i in range(pts.shape[0]):
            pt = pts[i]
            new_pt = np.array([pt[0], pt[1], 1.0], dtype=np.float32)
            new_pt = np.dot(M, new_pt)
            new_pts[i] = new_pt[0:2]

        return new_pts


    def enhance_face(self, ref_faceset: FaceSet, temp_frame, face: Face):
        # preprocess
        start_x, start_y, end_x, end_y = map(int, face['bbox'])
        lm106 = face.landmark_2d_106
        lq_landmarks = np.asarray(self.landmarks106_to_68(lm106))

        if temp_frame.shape[0] != 512 or temp_frame.shape[1] != 512:
            # scale to 512x512
            scale_factor = 512 / temp_frame.shape[1]

            M = face.matrix * scale_factor

            lq_landmarks = self.trans_points2d(lq_landmarks, M)
            temp_frame = cv2.resize(temp_frame, (512,512), interpolation = cv2.INTER_AREA)

        if temp_frame.ndim == 2:
            temp_frame = cv2.cvtColor(temp_frame, cv2.COLOR_GRAY2RGB) # GGG
        # else:
        #     temp_frame = cv2.cvtColor(temp_frame, cv2.COLOR_BGR2RGB) # RGB

        lq = read_img_tensor(temp_frame)

        LQLocs = get_component_location(lq_landmarks)
        # self.check_bbox(lq, LQLocs.unsqueeze(0))

        # specific, change 1000 to 1 to activate
        if len(ref_faceset.faces) > 1:
            SpecificImgs = []
            SpecificLocs = []
            for i,face in enumerate(ref_faceset.faces):
                lm106 = face.landmark_2d_106
                lq_landmarks = np.asarray(self.landmarks106_to_68(lm106))
                ref_image = ref_faceset.ref_images[i]
                if ref_image.shape[0] != 512 or ref_image.shape[1] != 512:
                    # scale to 512x512
                    scale_factor = 512 / ref_image.shape[1]

                    M = face.matrix * scale_factor

                    lq_landmarks = self.trans_points2d(lq_landmarks, M)
                    ref_image = cv2.resize(ref_image, (512,512), interpolation = cv2.INTER_AREA)

                if ref_image.ndim == 2:
                    ref_image = cv2.cvtColor(ref_image, cv2.COLOR_GRAY2RGB) # GGG
                # else:
                #     ref_image = cv2.cvtColor(ref_image, cv2.COLOR_BGR2RGB) # RGB

                ref_tensor = read_img_tensor(ref_image)
                ref_locs = get_component_location(lq_landmarks)
                # self.check_bbox(ref_tensor, ref_locs.unsqueeze(0))

                SpecificImgs.append(ref_tensor)
                SpecificLocs.append(ref_locs.unsqueeze(0))

            SpecificImgs = torch.cat(SpecificImgs, dim=0)
            SpecificLocs = torch.cat(SpecificLocs, dim=0)
            # check_bbox(SpecificImgs, SpecificLocs)
            SpMem256, SpMem128, SpMem64 = self.model_dmdnet.generate_specific_dictionary(sp_imgs = SpecificImgs.to(self.torchdevice), sp_locs = SpecificLocs)
            SpMem256Para = {}
            SpMem128Para = {}
            SpMem64Para = {}
            for k, v in SpMem256.items():
                SpMem256Para[k] = v
            for k, v in SpMem128.items():
                SpMem128Para[k] = v
            for k, v in SpMem64.items():
                SpMem64Para[k] = v
        else:
            # generic
            SpMem256Para, SpMem128Para, SpMem64Para = None, None, None

        with torch.no_grad():
            with THREAD_LOCK_DMDNET:
                try:
                    GenericResult, SpecificResult = self.model_dmdnet(lq = lq.to(self.torchdevice), loc = LQLocs.unsqueeze(0), sp_256 = SpMem256Para, sp_128 = SpMem128Para, sp_64 = SpMem64Para)
                except Exception as e:
                    print(f'Error {e} there may be something wrong with the detected component locations.')
                    return temp_frame

        if SpecificResult is not None:
            save_specific = SpecificResult * 0.5 + 0.5
            save_specific = save_specific.squeeze(0).permute(1, 2, 0).flip(2) # RGB->BGR
            save_specific = np.clip(save_specific.float().cpu().numpy(), 0, 1) * 255.0
            temp_frame = save_specific.astype("uint8")
            if False:
                save_generic = GenericResult * 0.5 + 0.5
                save_generic = save_generic.squeeze(0).permute(1, 2, 0).flip(2) # RGB->BGR
                save_generic = np.clip(save_generic.float().cpu().numpy(), 0, 1) * 255.0
                check_lq = lq * 0.5 + 0.5
                check_lq = check_lq.squeeze(0).permute(1, 2, 0).flip(2) # RGB->BGR
                check_lq = np.clip(check_lq.float().cpu().numpy(), 0, 1) * 255.0
                cv2.imwrite('dmdnet_comparison.png', cv2.cvtColor(np.hstack((check_lq, save_generic, save_specific)),cv2.COLOR_RGB2BGR))
        else:
            save_generic = GenericResult * 0.5 + 0.5
            save_generic = save_generic.squeeze(0).permute(1, 2, 0).flip(2) # RGB->BGR
            save_generic = np.clip(save_generic.float().cpu().numpy(), 0, 1) * 255.0
            temp_frame = save_generic.astype("uint8")
        temp_frame = cv2.cvtColor(temp_frame, cv2.COLOR_RGB2BGR) # RGB
        return temp_frame


    def create(self, devicename):
        self.torchdevice = torch.device(devicename)
        model_dmdnet = DMDNet().to(self.torchdevice)
        weights = torch.load('./models/DMDNet.pth')
        model_dmdnet.load_state_dict(weights, strict=True)

        model_dmdnet.eval()
        num_params = 0
        for param in model_dmdnet.parameters():
            num_params += param.numel()
        return model_dmdnet

        # print('{:>8s} : {}'.format('Using device', device))
        # print('{:>8s} : {:.2f}M'.format('Model params', num_params/1e6))


def read_img_tensor(Img=None): #rgb -1~1
    Img = Img.transpose((2, 0, 1))/255.0
    Img = torch.from_numpy(Img).float()
    normalize(Img, [0.5,0.5,0.5], [0.5,0.5,0.5], inplace=True)
    ImgTensor = Img.unsqueeze(0)
    return ImgTensor


def get_component_location(Landmarks, re_read=False):
    if re_read:
        ReadLandmark = []
        with open(Landmarks,'r') as f:
            for line in f:
                tmp = [float(i) for i in line.split(' ') if i != '\n']
                ReadLandmark.append(tmp)
        ReadLandmark = np.array(ReadLandmark) #
        Landmarks = np.reshape(ReadLandmark, [-1, 2]) # 68*2
    Map_LE_B = list(np.hstack((range(17,22), range(36,42))))
    Map_RE_B = list(np.hstack((range(22,27), range(42,48))))
    Map_LE = list(range(36,42))
    Map_RE = list(range(42,48))
    Map_NO = list(range(29,36))
    Map_MO = list(range(48,68))

    Landmarks[Landmarks>504]=504
    Landmarks[Landmarks<8]=8

    #left eye
    Mean_LE = np.mean(Landmarks[Map_LE],0)
    L_LE1 = Mean_LE[1] - np.min(Landmarks[Map_LE_B,1])
    L_LE1 = L_LE1 * 1.3
    L_LE2 = L_LE1 / 1.9
    L_LE_xy = L_LE1 + L_LE2
    L_LE_lt = [L_LE_xy/2, L_LE1]
    L_LE_rb = [L_LE_xy/2, L_LE2]
    Location_LE = np.hstack((Mean_LE - L_LE_lt + 1, Mean_LE + L_LE_rb)).astype(int)

    #right eye
    Mean_RE = np.mean(Landmarks[Map_RE],0)
    L_RE1 = Mean_RE[1] - np.min(Landmarks[Map_RE_B,1])
    L_RE1 = L_RE1 * 1.3
    L_RE2 = L_RE1 / 1.9
    L_RE_xy = L_RE1 + L_RE2
    L_RE_lt = [L_RE_xy/2, L_RE1]
    L_RE_rb = [L_RE_xy/2, L_RE2]
    Location_RE = np.hstack((Mean_RE - L_RE_lt + 1, Mean_RE + L_RE_rb)).astype(int)

    #nose
    Mean_NO = np.mean(Landmarks[Map_NO],0)
    L_NO1 =( np.max([Mean_NO[0] - Landmarks[31][0], Landmarks[35][0] - Mean_NO[0]])) * 1.25
    L_NO2 = (Landmarks[33][1] - Mean_NO[1]) * 1.1
    L_NO_xy = L_NO1 * 2
    L_NO_lt = [L_NO_xy/2, L_NO_xy - L_NO2]
    L_NO_rb = [L_NO_xy/2, L_NO2]
    Location_NO = np.hstack((Mean_NO - L_NO_lt + 1, Mean_NO + L_NO_rb)).astype(int)

    #mouth
    Mean_MO = np.mean(Landmarks[Map_MO],0)
    L_MO = np.max((np.max(np.max(Landmarks[Map_MO],0) - np.min(Landmarks[Map_MO],0))/2,16)) * 1.1
    MO_O = Mean_MO - L_MO + 1
    MO_T = Mean_MO + L_MO
    MO_T[MO_T>510]=510
    Location_MO = np.hstack((MO_O, MO_T)).astype(int)
    return torch.cat([torch.FloatTensor(Location_LE).unsqueeze(0), torch.FloatTensor(Location_RE).unsqueeze(0), torch.FloatTensor(Location_NO).unsqueeze(0), torch.FloatTensor(Location_MO).unsqueeze(0)], dim=0)


def calc_mean_std_4D(feat, eps=1e-5):
    # eps is a small value added to the variance to avoid divide-by-zero.
    size = feat.size()
    assert (len(size) == 4)
    N, C = size[:2]
    feat_var = feat.view(N, C, -1).var(dim=2) + eps
    feat_std = feat_var.sqrt().view(N, C, 1, 1)
    feat_mean = feat.view(N, C, -1).mean(dim=2).view(N, C, 1, 1)
    return feat_mean, feat_std

def adaptive_instance_normalization_4D(content_feat, style_feat): # content_feat is ref feature, style is degradate feature
    size = content_feat.size()
    style_mean, style_std = calc_mean_std_4D(style_feat)
    content_mean, content_std = calc_mean_std_4D(content_feat)
    normalized_feat = (content_feat - content_mean.expand(size)) / content_std.expand(size)
    return normalized_feat * style_std.expand(size) + style_mean.expand(size)


def convU(in_channels, out_channels,conv_layer, norm_layer, kernel_size=3, stride=1,dilation=1, bias=True):
    return nn.Sequential(
        SpectralNorm(conv_layer(in_channels, out_channels, kernel_size=kernel_size, stride=stride, dilation=dilation, padding=((kernel_size-1)//2)*dilation, bias=bias)),
        nn.LeakyReLU(0.2),
        SpectralNorm(conv_layer(out_channels, out_channels, kernel_size=kernel_size, stride=stride, dilation=dilation, padding=((kernel_size-1)//2)*dilation, bias=bias)),
    )


class MSDilateBlock(nn.Module):
    def __init__(self, in_channels,conv_layer=nn.Conv2d, norm_layer=nn.BatchNorm2d, kernel_size=3, dilation=[1,1,1,1], bias=True):
        super(MSDilateBlock, self).__init__()
        self.conv1 = convU(in_channels, in_channels,conv_layer, norm_layer, kernel_size,dilation=dilation[0], bias=bias)
        self.conv2 = convU(in_channels, in_channels,conv_layer, norm_layer, kernel_size,dilation=dilation[1], bias=bias)
        self.conv3 = convU(in_channels, in_channels,conv_layer, norm_layer, kernel_size,dilation=dilation[2], bias=bias)
        self.conv4 = convU(in_channels, in_channels,conv_layer, norm_layer, kernel_size,dilation=dilation[3], bias=bias)
        self.convi = SpectralNorm(conv_layer(in_channels*4, in_channels, kernel_size=kernel_size, stride=1, padding=(kernel_size-1)//2, bias=bias))
    def forward(self, x):
        conv1 = self.conv1(x)
        conv2 = self.conv2(x)
        conv3 = self.conv3(x)
        conv4 = self.conv4(x)
        cat = torch.cat([conv1, conv2, conv3, conv4], 1)
        out = self.convi(cat) + x
        return out


class AdaptiveInstanceNorm(nn.Module):
    def __init__(self, in_channel):
        super().__init__()
        self.norm = nn.InstanceNorm2d(in_channel)

    def forward(self, input, style):
        style_mean, style_std = calc_mean_std_4D(style)
        out = self.norm(input)
        size = input.size()
        out = style_std.expand(size) * out + style_mean.expand(size)
        return out

class NoiseInjection(nn.Module):
    def __init__(self, channel):
        super().__init__()
        self.weight = nn.Parameter(torch.zeros(1, channel, 1, 1))
    def forward(self, image, noise):
        if noise is None:
            b, c, h, w = image.shape
            noise = image.new_empty(b, 1, h, w).normal_()
        return image + self.weight * noise

class StyledUpBlock(nn.Module):
    def __init__(self, in_channel, out_channel, kernel_size=3, padding=1,upsample=False, noise_inject=False):
        super().__init__()

        self.noise_inject = noise_inject
        if upsample:
            self.conv1 = nn.Sequential(
                nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False),
                SpectralNorm(nn.Conv2d(in_channel, out_channel, kernel_size, padding=padding)),
                nn.LeakyReLU(0.2),
            )
        else:
            self.conv1 = nn.Sequential(
                SpectralNorm(nn.Conv2d(in_channel, out_channel, kernel_size, padding=padding)),
                nn.LeakyReLU(0.2),
                SpectralNorm(nn.Conv2d(out_channel, out_channel, kernel_size, padding=padding)),
            )
        self.convup = nn.Sequential(
            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False),
            SpectralNorm(nn.Conv2d(out_channel, out_channel, kernel_size, padding=padding)),
            nn.LeakyReLU(0.2),
            SpectralNorm(nn.Conv2d(out_channel, out_channel, kernel_size, padding=padding)),
        )
        if self.noise_inject:
            self.noise1 = NoiseInjection(out_channel)

        self.lrelu1 = nn.LeakyReLU(0.2)

        self.ScaleModel1 = nn.Sequential(
            SpectralNorm(nn.Conv2d(in_channel,out_channel,3, 1, 1)),
            nn.LeakyReLU(0.2),
            SpectralNorm(nn.Conv2d(out_channel, out_channel, 3, 1, 1))
        )
        self.ShiftModel1 = nn.Sequential(
            SpectralNorm(nn.Conv2d(in_channel,out_channel,3, 1, 1)),
            nn.LeakyReLU(0.2),
            SpectralNorm(nn.Conv2d(out_channel, out_channel, 3, 1, 1)),
        )

    def forward(self, input, style):
        out = self.conv1(input)
        out = self.lrelu1(out)
        Shift1 = self.ShiftModel1(style)
        Scale1 = self.ScaleModel1(style)
        out = out * Scale1 + Shift1
        if self.noise_inject:
            out = self.noise1(out, noise=None)
        outup = self.convup(out)
        return outup


####################################################################
###############Face Dictionary Generator
####################################################################
def AttentionBlock(in_channel):
    return nn.Sequential(
        SpectralNorm(nn.Conv2d(in_channel, in_channel, 3, 1, 1)),
        nn.LeakyReLU(0.2),
        SpectralNorm(nn.Conv2d(in_channel, in_channel, 3, 1, 1)),
    )

class DilateResBlock(nn.Module):
    def __init__(self, dim, dilation=[5,3] ):
        super(DilateResBlock, self).__init__()
        self.Res = nn.Sequential(
            SpectralNorm(nn.Conv2d(dim, dim, 3, 1, ((3-1)//2)*dilation[0], dilation[0])),
            nn.LeakyReLU(0.2),
            SpectralNorm(nn.Conv2d(dim, dim, 3, 1, ((3-1)//2)*dilation[1], dilation[1])),
        )
    def forward(self, x):
        out = x + self.Res(x)
        return out


class KeyValue(nn.Module):
    def __init__(self, indim, keydim, valdim):
        super(KeyValue, self).__init__()
        self.Key = nn.Sequential(
            SpectralNorm(nn.Conv2d(indim, keydim, kernel_size=(3,3), padding=(1,1), stride=1)),
            nn.LeakyReLU(0.2),
            SpectralNorm(nn.Conv2d(keydim, keydim, kernel_size=(3,3), padding=(1,1), stride=1)),
        )
        self.Value = nn.Sequential(
            SpectralNorm(nn.Conv2d(indim, valdim, kernel_size=(3,3), padding=(1,1), stride=1)),
            nn.LeakyReLU(0.2),
            SpectralNorm(nn.Conv2d(valdim, valdim, kernel_size=(3,3), padding=(1,1), stride=1)),
        )
    def forward(self, x):
        return self.Key(x), self.Value(x)

class MaskAttention(nn.Module):
    def __init__(self, indim):
        super(MaskAttention, self).__init__()
        self.conv1 = nn.Sequential(
            SpectralNorm(nn.Conv2d(indim, indim//3, kernel_size=(3,3), padding=(1,1), stride=1)),
            nn.LeakyReLU(0.2),
            SpectralNorm(nn.Conv2d(indim//3, indim//3, kernel_size=(3,3), padding=(1,1), stride=1)),
        )
        self.conv2 = nn.Sequential(
            SpectralNorm(nn.Conv2d(indim, indim//3, kernel_size=(3,3), padding=(1,1), stride=1)),
            nn.LeakyReLU(0.2),
            SpectralNorm(nn.Conv2d(indim//3, indim//3, kernel_size=(3,3), padding=(1,1), stride=1)),
        )
        self.conv3 = nn.Sequential(
            SpectralNorm(nn.Conv2d(indim, indim//3, kernel_size=(3,3), padding=(1,1), stride=1)),
            nn.LeakyReLU(0.2),
            SpectralNorm(nn.Conv2d(indim//3, indim//3, kernel_size=(3,3), padding=(1,1), stride=1)),
        )
        self.convCat = nn.Sequential(
            SpectralNorm(nn.Conv2d(indim//3 * 3, indim, kernel_size=(3,3), padding=(1,1), stride=1)),
            nn.LeakyReLU(0.2),
            SpectralNorm(nn.Conv2d(indim, indim, kernel_size=(3,3), padding=(1,1), stride=1)),
        )
    def forward(self, x, y, z):
        c1 = self.conv1(x)
        c2 = self.conv2(y)
        c3 = self.conv3(z)
        return self.convCat(torch.cat([c1,c2,c3], dim=1))

class Query(nn.Module):
    def __init__(self, indim, quedim):
        super(Query, self).__init__()
        self.Query = nn.Sequential(
            SpectralNorm(nn.Conv2d(indim, quedim, kernel_size=(3,3), padding=(1,1), stride=1)),
            nn.LeakyReLU(0.2),
            SpectralNorm(nn.Conv2d(quedim, quedim, kernel_size=(3,3), padding=(1,1), stride=1)),
        )
    def forward(self, x):
        return self.Query(x)

def roi_align_self(input, location, target_size):
    test = (target_size.item(),target_size.item())
    return torch.cat([F.interpolate(input[i:i+1,:,location[i,1]:location[i,3],location[i,0]:location[i,2]],test,mode='bilinear',align_corners=False) for i in range(input.size(0))],0)

class FeatureExtractor(nn.Module):
    def __init__(self, ngf = 64, key_scale = 4):#
        super().__init__()

        self.key_scale = 4
        self.part_sizes = np.array([80,80,50,110]) #
        self.feature_sizes = np.array([256,128,64]) #

        self.conv1 = nn.Sequential(
            SpectralNorm(nn.Conv2d(3, ngf, 3, 2, 1)),
            nn.LeakyReLU(0.2),
            SpectralNorm(nn.Conv2d(ngf, ngf, 3, 1, 1)),
        )
        self.conv2 = nn.Sequential(
            SpectralNorm(nn.Conv2d(ngf, ngf, 3, 1, 1)),
            nn.LeakyReLU(0.2),
            SpectralNorm(nn.Conv2d(ngf, ngf, 3, 1, 1))
        )
        self.res1 = DilateResBlock(ngf, [5,3])
        self.res2 = DilateResBlock(ngf, [5,3])

        self.conv3 = nn.Sequential(
            SpectralNorm(nn.Conv2d(ngf, ngf*2, 3, 2, 1)),
            nn.LeakyReLU(0.2),
            SpectralNorm(nn.Conv2d(ngf*2, ngf*2, 3, 1, 1)),
        )
        self.conv4 = nn.Sequential(
            SpectralNorm(nn.Conv2d(ngf*2, ngf*2, 3, 1, 1)),
            nn.LeakyReLU(0.2),
            SpectralNorm(nn.Conv2d(ngf*2, ngf*2, 3, 1, 1))
        )
        self.res3 = DilateResBlock(ngf*2, [3,1])
        self.res4 = DilateResBlock(ngf*2, [3,1])

        self.conv5 = nn.Sequential(
            SpectralNorm(nn.Conv2d(ngf*2, ngf*4, 3, 2, 1)),
            nn.LeakyReLU(0.2),
            SpectralNorm(nn.Conv2d(ngf*4, ngf*4, 3, 1, 1)),
        )
        self.conv6 = nn.Sequential(
            SpectralNorm(nn.Conv2d(ngf*4, ngf*4, 3, 1, 1)),
            nn.LeakyReLU(0.2),
            SpectralNorm(nn.Conv2d(ngf*4, ngf*4, 3, 1, 1))
        )
        self.res5 = DilateResBlock(ngf*4, [1,1])
        self.res6 = DilateResBlock(ngf*4, [1,1])

        self.LE_256_Q = Query(ngf, ngf // self.key_scale)
        self.RE_256_Q = Query(ngf, ngf // self.key_scale)
        self.MO_256_Q = Query(ngf, ngf // self.key_scale)
        self.LE_128_Q = Query(ngf * 2, ngf * 2 // self.key_scale)
        self.RE_128_Q = Query(ngf * 2, ngf * 2 // self.key_scale)
        self.MO_128_Q = Query(ngf * 2, ngf * 2 // self.key_scale)
        self.LE_64_Q = Query(ngf * 4, ngf * 4 // self.key_scale)
        self.RE_64_Q = Query(ngf * 4, ngf * 4 // self.key_scale)
        self.MO_64_Q = Query(ngf * 4, ngf * 4 // self.key_scale)


    def forward(self, img, locs):
        le_location = locs[:,0,:].int().cpu().numpy()
        re_location = locs[:,1,:].int().cpu().numpy()
        no_location = locs[:,2,:].int().cpu().numpy()
        mo_location = locs[:,3,:].int().cpu().numpy()

        f1_0 = self.conv1(img)
        f1_1 = self.res1(f1_0)
        f2_0 = self.conv2(f1_1)
        f2_1 = self.res2(f2_0)

        f3_0 = self.conv3(f2_1)
        f3_1 = self.res3(f3_0)
        f4_0 = self.conv4(f3_1)
        f4_1 = self.res4(f4_0)

        f5_0 = self.conv5(f4_1)
        f5_1 = self.res5(f5_0)
        f6_0 = self.conv6(f5_1)
        f6_1 = self.res6(f6_0)

        ####ROI Align
        le_part_256 = roi_align_self(f2_1.clone(), le_location//2, self.part_sizes[0]//2)
        re_part_256 = roi_align_self(f2_1.clone(), re_location//2, self.part_sizes[1]//2)
        mo_part_256 = roi_align_self(f2_1.clone(), mo_location//2, self.part_sizes[3]//2)

        le_part_128 = roi_align_self(f4_1.clone(), le_location//4, self.part_sizes[0]//4)
        re_part_128 = roi_align_self(f4_1.clone(), re_location//4, self.part_sizes[1]//4)
        mo_part_128 = roi_align_self(f4_1.clone(), mo_location//4, self.part_sizes[3]//4)

        le_part_64 = roi_align_self(f6_1.clone(), le_location//8, self.part_sizes[0]//8)
        re_part_64 = roi_align_self(f6_1.clone(), re_location//8, self.part_sizes[1]//8)
        mo_part_64 = roi_align_self(f6_1.clone(), mo_location//8, self.part_sizes[3]//8)

        le_256_q = self.LE_256_Q(le_part_256)
        re_256_q = self.RE_256_Q(re_part_256)
        mo_256_q = self.MO_256_Q(mo_part_256)

        le_128_q = self.LE_128_Q(le_part_128)
        re_128_q = self.RE_128_Q(re_part_128)
        mo_128_q = self.MO_128_Q(mo_part_128)

        le_64_q = self.LE_64_Q(le_part_64)
        re_64_q = self.RE_64_Q(re_part_64)
        mo_64_q = self.MO_64_Q(mo_part_64)

        return {'f256': f2_1, 'f128': f4_1, 'f64': f6_1,
                'le256': le_part_256, 're256': re_part_256, 'mo256': mo_part_256,
                'le128': le_part_128, 're128': re_part_128, 'mo128': mo_part_128,
                'le64': le_part_64, 're64': re_part_64, 'mo64': mo_part_64,
                'le_256_q': le_256_q, 're_256_q': re_256_q, 'mo_256_q': mo_256_q,
                'le_128_q': le_128_q, 're_128_q': re_128_q, 'mo_128_q': mo_128_q,
                'le_64_q': le_64_q, 're_64_q': re_64_q, 'mo_64_q': mo_64_q}


class DMDNet(nn.Module):
    def __init__(self, ngf = 64, banks_num = 128):
        super().__init__()
        self.part_sizes = np.array([80,80,50,110]) # size for 512
        self.feature_sizes = np.array([256,128,64]) # size for 512

        self.banks_num = banks_num
        self.key_scale = 4

        self.E_lq = FeatureExtractor(key_scale = self.key_scale)
        self.E_hq = FeatureExtractor(key_scale = self.key_scale)

        self.LE_256_KV = KeyValue(ngf, ngf // self.key_scale, ngf)
        self.RE_256_KV = KeyValue(ngf, ngf // self.key_scale, ngf)
        self.MO_256_KV = KeyValue(ngf, ngf // self.key_scale, ngf)

        self.LE_128_KV = KeyValue(ngf * 2 , ngf * 2 // self.key_scale, ngf * 2)
        self.RE_128_KV = KeyValue(ngf * 2 , ngf * 2 // self.key_scale, ngf * 2)
        self.MO_128_KV = KeyValue(ngf * 2 , ngf * 2 // self.key_scale, ngf * 2)

        self.LE_64_KV = KeyValue(ngf * 4 , ngf * 4 // self.key_scale, ngf * 4)
        self.RE_64_KV = KeyValue(ngf * 4 , ngf * 4 // self.key_scale, ngf * 4)
        self.MO_64_KV = KeyValue(ngf * 4 , ngf * 4 // self.key_scale, ngf * 4)

        self.LE_256_Attention = AttentionBlock(64)
        self.RE_256_Attention = AttentionBlock(64)
        self.MO_256_Attention = AttentionBlock(64)

        self.LE_128_Attention = AttentionBlock(128)
        self.RE_128_Attention = AttentionBlock(128)
        self.MO_128_Attention = AttentionBlock(128)

        self.LE_64_Attention = AttentionBlock(256)
        self.RE_64_Attention = AttentionBlock(256)
        self.MO_64_Attention = AttentionBlock(256)

        self.LE_256_Mask = MaskAttention(64)
        self.RE_256_Mask = MaskAttention(64)
        self.MO_256_Mask = MaskAttention(64)

        self.LE_128_Mask = MaskAttention(128)
        self.RE_128_Mask = MaskAttention(128)
        self.MO_128_Mask = MaskAttention(128)

        self.LE_64_Mask = MaskAttention(256)
        self.RE_64_Mask = MaskAttention(256)
        self.MO_64_Mask = MaskAttention(256)

        self.MSDilate = MSDilateBlock(ngf*4, dilation = [4,3,2,1])

        self.up1 = StyledUpBlock(ngf*4, ngf*2, noise_inject=False) #
        self.up2 = StyledUpBlock(ngf*2, ngf, noise_inject=False) #
        self.up3 = StyledUpBlock(ngf, ngf, noise_inject=False) #
        self.up4 = nn.Sequential(
            SpectralNorm(nn.Conv2d(ngf, ngf, 3, 1, 1)),
            nn.LeakyReLU(0.2),
            UpResBlock(ngf),
            UpResBlock(ngf),
            SpectralNorm(nn.Conv2d(ngf, 3, kernel_size=3, stride=1, padding=1)),
            nn.Tanh()
        )

        # define generic memory, revise register_buffer to register_parameter for backward update
        self.register_buffer('le_256_mem_key', torch.randn(128,16,40,40))
        self.register_buffer('re_256_mem_key', torch.randn(128,16,40,40))
        self.register_buffer('mo_256_mem_key', torch.randn(128,16,55,55))
        self.register_buffer('le_256_mem_value', torch.randn(128,64,40,40))
        self.register_buffer('re_256_mem_value', torch.randn(128,64,40,40))
        self.register_buffer('mo_256_mem_value', torch.randn(128,64,55,55))

        self.register_buffer('le_128_mem_key', torch.randn(128,32,20,20))
        self.register_buffer('re_128_mem_key', torch.randn(128,32,20,20))
        self.register_buffer('mo_128_mem_key', torch.randn(128,32,27,27))
        self.register_buffer('le_128_mem_value', torch.randn(128,128,20,20))
        self.register_buffer('re_128_mem_value', torch.randn(128,128,20,20))
        self.register_buffer('mo_128_mem_value', torch.randn(128,128,27,27))

        self.register_buffer('le_64_mem_key', torch.randn(128,64,10,10))
        self.register_buffer('re_64_mem_key', torch.randn(128,64,10,10))
        self.register_buffer('mo_64_mem_key', torch.randn(128,64,13,13))
        self.register_buffer('le_64_mem_value', torch.randn(128,256,10,10))
        self.register_buffer('re_64_mem_value', torch.randn(128,256,10,10))
        self.register_buffer('mo_64_mem_value', torch.randn(128,256,13,13))


    def readMem(self, k, v, q):
        sim = F.conv2d(q, k)
        score = F.softmax(sim/sqrt(sim.size(1)), dim=1) #B * S * 1 * 1 6*128
        sb,sn,sw,sh = score.size()
        s_m = score.view(sb, -1).unsqueeze(1)#2*1*M
        vb,vn,vw,vh = v.size()
        v_in = v.view(vb, -1).repeat(sb,1,1)#2*M*(c*w*h)
        mem_out = torch.bmm(s_m, v_in).squeeze(1).view(sb, vn, vw,vh)
        max_inds = torch.argmax(score, dim=1).squeeze()
        return mem_out, max_inds


    def memorize(self, img, locs):
        fs = self.E_hq(img, locs)
        LE256_key, LE256_value = self.LE_256_KV(fs['le256'])
        RE256_key, RE256_value = self.RE_256_KV(fs['re256'])
        MO256_key, MO256_value = self.MO_256_KV(fs['mo256'])

        LE128_key, LE128_value = self.LE_128_KV(fs['le128'])
        RE128_key, RE128_value = self.RE_128_KV(fs['re128'])
        MO128_key, MO128_value = self.MO_128_KV(fs['mo128'])

        LE64_key, LE64_value = self.LE_64_KV(fs['le64'])
        RE64_key, RE64_value = self.RE_64_KV(fs['re64'])
        MO64_key, MO64_value = self.MO_64_KV(fs['mo64'])

        Mem256 = {'LE256Key': LE256_key, 'LE256Value': LE256_value, 'RE256Key': RE256_key, 'RE256Value': RE256_value,'MO256Key': MO256_key, 'MO256Value': MO256_value}
        Mem128 = {'LE128Key': LE128_key, 'LE128Value': LE128_value, 'RE128Key': RE128_key, 'RE128Value': RE128_value,'MO128Key': MO128_key, 'MO128Value': MO128_value}
        Mem64 = {'LE64Key': LE64_key, 'LE64Value': LE64_value, 'RE64Key': RE64_key, 'RE64Value': RE64_value,'MO64Key': MO64_key, 'MO64Value': MO64_value}

        FS256 = {'LE256F':fs['le256'], 'RE256F':fs['re256'], 'MO256F':fs['mo256']}
        FS128 = {'LE128F':fs['le128'], 'RE128F':fs['re128'], 'MO128F':fs['mo128']}
        FS64 = {'LE64F':fs['le64'], 'RE64F':fs['re64'], 'MO64F':fs['mo64']}

        return Mem256, Mem128, Mem64

    def enhancer(self, fs_in, sp_256=None, sp_128=None, sp_64=None):
        le_256_q = fs_in['le_256_q']
        re_256_q = fs_in['re_256_q']
        mo_256_q = fs_in['mo_256_q']

        le_128_q = fs_in['le_128_q']
        re_128_q = fs_in['re_128_q']
        mo_128_q = fs_in['mo_128_q']

        le_64_q = fs_in['le_64_q']
        re_64_q = fs_in['re_64_q']
        mo_64_q = fs_in['mo_64_q']

        ####for 256
        le_256_mem_g, le_256_inds = self.readMem(self.le_256_mem_key, self.le_256_mem_value, le_256_q)
        re_256_mem_g, re_256_inds = self.readMem(self.re_256_mem_key, self.re_256_mem_value, re_256_q)
        mo_256_mem_g, mo_256_inds = self.readMem(self.mo_256_mem_key, self.mo_256_mem_value, mo_256_q)

        le_128_mem_g, le_128_inds = self.readMem(self.le_128_mem_key, self.le_128_mem_value, le_128_q)
        re_128_mem_g, re_128_inds = self.readMem(self.re_128_mem_key, self.re_128_mem_value, re_128_q)
        mo_128_mem_g, mo_128_inds = self.readMem(self.mo_128_mem_key, self.mo_128_mem_value, mo_128_q)

        le_64_mem_g, le_64_inds = self.readMem(self.le_64_mem_key, self.le_64_mem_value, le_64_q)
        re_64_mem_g, re_64_inds = self.readMem(self.re_64_mem_key, self.re_64_mem_value, re_64_q)
        mo_64_mem_g, mo_64_inds = self.readMem(self.mo_64_mem_key, self.mo_64_mem_value, mo_64_q)

        if sp_256 is not None and sp_128 is not None and sp_64 is not None:
            le_256_mem_s, _ = self.readMem(sp_256['LE256Key'], sp_256['LE256Value'], le_256_q)
            re_256_mem_s, _ = self.readMem(sp_256['RE256Key'], sp_256['RE256Value'], re_256_q)
            mo_256_mem_s, _ = self.readMem(sp_256['MO256Key'], sp_256['MO256Value'], mo_256_q)
            le_256_mask = self.LE_256_Mask(fs_in['le256'],le_256_mem_s,le_256_mem_g)
            le_256_mem = le_256_mask*le_256_mem_s + (1-le_256_mask)*le_256_mem_g
            re_256_mask = self.RE_256_Mask(fs_in['re256'],re_256_mem_s,re_256_mem_g)
            re_256_mem = re_256_mask*re_256_mem_s + (1-re_256_mask)*re_256_mem_g
            mo_256_mask = self.MO_256_Mask(fs_in['mo256'],mo_256_mem_s,mo_256_mem_g)
            mo_256_mem = mo_256_mask*mo_256_mem_s + (1-mo_256_mask)*mo_256_mem_g

            le_128_mem_s, _ = self.readMem(sp_128['LE128Key'], sp_128['LE128Value'], le_128_q)
            re_128_mem_s, _ = self.readMem(sp_128['RE128Key'], sp_128['RE128Value'], re_128_q)
            mo_128_mem_s, _ = self.readMem(sp_128['MO128Key'], sp_128['MO128Value'], mo_128_q)
            le_128_mask = self.LE_128_Mask(fs_in['le128'],le_128_mem_s,le_128_mem_g)
            le_128_mem = le_128_mask*le_128_mem_s + (1-le_128_mask)*le_128_mem_g
            re_128_mask = self.RE_128_Mask(fs_in['re128'],re_128_mem_s,re_128_mem_g)
            re_128_mem = re_128_mask*re_128_mem_s + (1-re_128_mask)*re_128_mem_g
            mo_128_mask = self.MO_128_Mask(fs_in['mo128'],mo_128_mem_s,mo_128_mem_g)
            mo_128_mem = mo_128_mask*mo_128_mem_s + (1-mo_128_mask)*mo_128_mem_g

            le_64_mem_s, _ = self.readMem(sp_64['LE64Key'], sp_64['LE64Value'], le_64_q)
            re_64_mem_s, _ = self.readMem(sp_64['RE64Key'], sp_64['RE64Value'], re_64_q)
            mo_64_mem_s, _ = self.readMem(sp_64['MO64Key'], sp_64['MO64Value'], mo_64_q)
            le_64_mask = self.LE_64_Mask(fs_in['le64'],le_64_mem_s,le_64_mem_g)
            le_64_mem = le_64_mask*le_64_mem_s + (1-le_64_mask)*le_64_mem_g
            re_64_mask = self.RE_64_Mask(fs_in['re64'],re_64_mem_s,re_64_mem_g)
            re_64_mem = re_64_mask*re_64_mem_s + (1-re_64_mask)*re_64_mem_g
            mo_64_mask = self.MO_64_Mask(fs_in['mo64'],mo_64_mem_s,mo_64_mem_g)
            mo_64_mem = mo_64_mask*mo_64_mem_s + (1-mo_64_mask)*mo_64_mem_g
        else:
            le_256_mem = le_256_mem_g
            re_256_mem = re_256_mem_g
            mo_256_mem = mo_256_mem_g
            le_128_mem = le_128_mem_g
            re_128_mem = re_128_mem_g
            mo_128_mem = mo_128_mem_g
            le_64_mem = le_64_mem_g
            re_64_mem = re_64_mem_g
            mo_64_mem = mo_64_mem_g

        le_256_mem_norm = adaptive_instance_normalization_4D(le_256_mem, fs_in['le256'])
        re_256_mem_norm = adaptive_instance_normalization_4D(re_256_mem, fs_in['re256'])
        mo_256_mem_norm = adaptive_instance_normalization_4D(mo_256_mem, fs_in['mo256'])

        ####for 128
        le_128_mem_norm = adaptive_instance_normalization_4D(le_128_mem, fs_in['le128'])
        re_128_mem_norm = adaptive_instance_normalization_4D(re_128_mem, fs_in['re128'])
        mo_128_mem_norm = adaptive_instance_normalization_4D(mo_128_mem, fs_in['mo128'])

        ####for 64
        le_64_mem_norm = adaptive_instance_normalization_4D(le_64_mem, fs_in['le64'])
        re_64_mem_norm = adaptive_instance_normalization_4D(re_64_mem, fs_in['re64'])
        mo_64_mem_norm = adaptive_instance_normalization_4D(mo_64_mem, fs_in['mo64'])

        EnMem256 = {'LE256Norm': le_256_mem_norm, 'RE256Norm': re_256_mem_norm, 'MO256Norm': mo_256_mem_norm}
        EnMem128 = {'LE128Norm': le_128_mem_norm, 'RE128Norm': re_128_mem_norm, 'MO128Norm': mo_128_mem_norm}
        EnMem64 = {'LE64Norm': le_64_mem_norm, 'RE64Norm': re_64_mem_norm, 'MO64Norm': mo_64_mem_norm}
        Ind256 = {'LE': le_256_inds, 'RE': re_256_inds, 'MO': mo_256_inds}
        Ind128 = {'LE': le_128_inds, 'RE': re_128_inds, 'MO': mo_128_inds}
        Ind64 = {'LE': le_64_inds, 'RE': re_64_inds, 'MO': mo_64_inds}
        return EnMem256, EnMem128, EnMem64, Ind256, Ind128, Ind64

    def reconstruct(self, fs_in, locs, memstar):
        le_256_mem_norm, re_256_mem_norm, mo_256_mem_norm = memstar[0]['LE256Norm'], memstar[0]['RE256Norm'], memstar[0]['MO256Norm']
        le_128_mem_norm, re_128_mem_norm, mo_128_mem_norm = memstar[1]['LE128Norm'], memstar[1]['RE128Norm'], memstar[1]['MO128Norm']
        le_64_mem_norm, re_64_mem_norm, mo_64_mem_norm = memstar[2]['LE64Norm'], memstar[2]['RE64Norm'], memstar[2]['MO64Norm']

        le_256_final = self.LE_256_Attention(le_256_mem_norm - fs_in['le256']) * le_256_mem_norm + fs_in['le256']
        re_256_final = self.RE_256_Attention(re_256_mem_norm - fs_in['re256']) * re_256_mem_norm + fs_in['re256']
        mo_256_final = self.MO_256_Attention(mo_256_mem_norm - fs_in['mo256']) * mo_256_mem_norm + fs_in['mo256']

        le_128_final = self.LE_128_Attention(le_128_mem_norm - fs_in['le128']) * le_128_mem_norm + fs_in['le128']
        re_128_final = self.RE_128_Attention(re_128_mem_norm - fs_in['re128']) * re_128_mem_norm + fs_in['re128']
        mo_128_final = self.MO_128_Attention(mo_128_mem_norm - fs_in['mo128']) * mo_128_mem_norm + fs_in['mo128']

        le_64_final = self.LE_64_Attention(le_64_mem_norm - fs_in['le64']) * le_64_mem_norm + fs_in['le64']
        re_64_final = self.RE_64_Attention(re_64_mem_norm - fs_in['re64']) * re_64_mem_norm + fs_in['re64']
        mo_64_final = self.MO_64_Attention(mo_64_mem_norm - fs_in['mo64']) * mo_64_mem_norm + fs_in['mo64']

        le_location = locs[:,0,:]
        re_location = locs[:,1,:]
        mo_location = locs[:,3,:]

        # Somehow with latest Torch it doesn't like numpy wrappers anymore

        # le_location = le_location.cpu().int().numpy()
        # re_location = re_location.cpu().int().numpy()
        # mo_location = mo_location.cpu().int().numpy()
        le_location = le_location.cpu().int()
        re_location = re_location.cpu().int()
        mo_location = mo_location.cpu().int()

        up_in_256 = fs_in['f256'].clone()# * 0
        up_in_128 = fs_in['f128'].clone()# * 0
        up_in_64 = fs_in['f64'].clone()# * 0

        for i in range(fs_in['f256'].size(0)):
            up_in_256[i:i+1,:,le_location[i,1]//2:le_location[i,3]//2,le_location[i,0]//2:le_location[i,2]//2] = F.interpolate(le_256_final[i:i+1,:,:,:].clone(), (le_location[i,3]//2-le_location[i,1]//2,le_location[i,2]//2-le_location[i,0]//2),mode='bilinear',align_corners=False)
            up_in_256[i:i+1,:,re_location[i,1]//2:re_location[i,3]//2,re_location[i,0]//2:re_location[i,2]//2] = F.interpolate(re_256_final[i:i+1,:,:,:].clone(), (re_location[i,3]//2-re_location[i,1]//2,re_location[i,2]//2-re_location[i,0]//2),mode='bilinear',align_corners=False)
            up_in_256[i:i+1,:,mo_location[i,1]//2:mo_location[i,3]//2,mo_location[i,0]//2:mo_location[i,2]//2] = F.interpolate(mo_256_final[i:i+1,:,:,:].clone(), (mo_location[i,3]//2-mo_location[i,1]//2,mo_location[i,2]//2-mo_location[i,0]//2),mode='bilinear',align_corners=False)

            up_in_128[i:i+1,:,le_location[i,1]//4:le_location[i,3]//4,le_location[i,0]//4:le_location[i,2]//4] = F.interpolate(le_128_final[i:i+1,:,:,:].clone(), (le_location[i,3]//4-le_location[i,1]//4,le_location[i,2]//4-le_location[i,0]//4),mode='bilinear',align_corners=False)
            up_in_128[i:i+1,:,re_location[i,1]//4:re_location[i,3]//4,re_location[i,0]//4:re_location[i,2]//4] = F.interpolate(re_128_final[i:i+1,:,:,:].clone(), (re_location[i,3]//4-re_location[i,1]//4,re_location[i,2]//4-re_location[i,0]//4),mode='bilinear',align_corners=False)
            up_in_128[i:i+1,:,mo_location[i,1]//4:mo_location[i,3]//4,mo_location[i,0]//4:mo_location[i,2]//4] = F.interpolate(mo_128_final[i:i+1,:,:,:].clone(), (mo_location[i,3]//4-mo_location[i,1]//4,mo_location[i,2]//4-mo_location[i,0]//4),mode='bilinear',align_corners=False)

            up_in_64[i:i+1,:,le_location[i,1]//8:le_location[i,3]//8,le_location[i,0]//8:le_location[i,2]//8] = F.interpolate(le_64_final[i:i+1,:,:,:].clone(), (le_location[i,3]//8-le_location[i,1]//8,le_location[i,2]//8-le_location[i,0]//8),mode='bilinear',align_corners=False)
            up_in_64[i:i+1,:,re_location[i,1]//8:re_location[i,3]//8,re_location[i,0]//8:re_location[i,2]//8] = F.interpolate(re_64_final[i:i+1,:,:,:].clone(), (re_location[i,3]//8-re_location[i,1]//8,re_location[i,2]//8-re_location[i,0]//8),mode='bilinear',align_corners=False)
            up_in_64[i:i+1,:,mo_location[i,1]//8:mo_location[i,3]//8,mo_location[i,0]//8:mo_location[i,2]//8] = F.interpolate(mo_64_final[i:i+1,:,:,:].clone(), (mo_location[i,3]//8-mo_location[i,1]//8,mo_location[i,2]//8-mo_location[i,0]//8),mode='bilinear',align_corners=False)

        ms_in_64 = self.MSDilate(fs_in['f64'].clone())
        fea_up1 = self.up1(ms_in_64, up_in_64)
        fea_up2 = self.up2(fea_up1, up_in_128) #
        fea_up3 = self.up3(fea_up2, up_in_256) #
        output = self.up4(fea_up3) #
        return output

    def generate_specific_dictionary(self, sp_imgs=None, sp_locs=None):
        return self.memorize(sp_imgs, sp_locs)

    def forward(self, lq=None, loc=None, sp_256 = None, sp_128 = None, sp_64 = None):
        try:
            fs_in = self.E_lq(lq, loc) # low quality images
        except Exception as e:
            print(e)

        GeMemNorm256, GeMemNorm128, GeMemNorm64, Ind256, Ind128, Ind64 = self.enhancer(fs_in)
        GeOut = self.reconstruct(fs_in, loc, memstar = [GeMemNorm256, GeMemNorm128, GeMemNorm64])
        if sp_256 is not None and sp_128 is not None and sp_64 is not None:
            GSMemNorm256, GSMemNorm128, GSMemNorm64, _, _, _ = self.enhancer(fs_in, sp_256, sp_128, sp_64)
            GSOut = self.reconstruct(fs_in, loc, memstar = [GSMemNorm256, GSMemNorm128, GSMemNorm64])
        else:
            GSOut = None
        return GeOut, GSOut

class UpResBlock(nn.Module):
    def __init__(self, dim, conv_layer = nn.Conv2d, norm_layer = nn.BatchNorm2d):
        super(UpResBlock, self).__init__()
        self.Model = nn.Sequential(
            SpectralNorm(conv_layer(dim, dim, 3, 1, 1)),
            nn.LeakyReLU(0.2),
            SpectralNorm(conv_layer(dim, dim, 3, 1, 1)),
        )
    def forward(self, x):
        out = x + self.Model(x)
        return out
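Enhance_DMDNet differs from the ONNX-based enhancers: it runs the PyTorch DMDNet defined above and, when the source FaceSet carries more than one reference face, builds a "specific" dictionary from those references and blends it with the generic memory banks (the len(ref_faceset.faces) > 1 branch in enhance_face). Below is a minimal sketch of exercising the bare DMDNet module with the helper functions from this file; the image path, the 68-point landmark array and loading the checkpoint outside create() are assumptions for illustration, not part of the diff.

# Hypothetical standalone use of the DMDNet module defined above (not part of this commit).
import cv2
import numpy as np
import torch

model = DMDNet()
model.load_state_dict(torch.load('./models/DMDNet.pth'), strict=True)   # same checkpoint path as create()
model.eval()

face_512 = cv2.imread('aligned_face_512.png')   # assumed: a 512x512 aligned face crop
landmarks_68 = np.load('landmarks_68.npy')      # assumed: 68x2 landmark array in 512x512 coordinates

lq = read_img_tensor(face_512)                  # 1x3x512x512 tensor scaled to [-1, 1]
locs = get_component_location(landmarks_68)     # 4x4 boxes: left eye, right eye, nose, mouth

with torch.no_grad():
    generic_out, specific_out = model(lq=lq, loc=locs.unsqueeze(0))  # specific_out is None without reference dictionaries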
roop/processors/Enhance_GFPGAN.py
ADDED
@@ -0,0 +1,73 @@
from typing import Any, List, Callable
import cv2
import numpy as np
import onnxruntime
import roop.globals

from roop.typing import Face, Frame, FaceSet
from roop.utilities import resolve_relative_path


class Enhance_GFPGAN():
    plugin_options:dict = None

    model_gfpgan = None
    name = None
    devicename = None

    processorname = 'gfpgan'
    type = 'enhance'


    def Initialize(self, plugin_options:dict):
        if self.plugin_options is not None:
            if self.plugin_options["devicename"] != plugin_options["devicename"]:
                self.Release()

        self.plugin_options = plugin_options
        if self.model_gfpgan is None:
            model_path = resolve_relative_path('../models/GFPGANv1.4.onnx')
            self.model_gfpgan = onnxruntime.InferenceSession(model_path, None, providers=roop.globals.execution_providers)
            # replace Mac mps with cpu for the moment
            self.devicename = self.plugin_options["devicename"].replace('mps', 'cpu')

        self.name = self.model_gfpgan.get_inputs()[0].name

    def Run(self, source_faceset: FaceSet, target_face: Face, temp_frame: Frame) -> Frame:
        # preprocess
        input_size = temp_frame.shape[1]
        temp_frame = cv2.resize(temp_frame, (512, 512), interpolation=cv2.INTER_CUBIC)

        temp_frame = cv2.cvtColor(temp_frame, cv2.COLOR_BGR2RGB)
        temp_frame = temp_frame.astype('float32') / 255.0
        temp_frame = (temp_frame - 0.5) / 0.5
        temp_frame = np.expand_dims(temp_frame, axis=0).transpose(0, 3, 1, 2)

        io_binding = self.model_gfpgan.io_binding()
        io_binding.bind_cpu_input("input", temp_frame)
        io_binding.bind_output("1288", self.devicename)
        self.model_gfpgan.run_with_iobinding(io_binding)
        ort_outs = io_binding.copy_outputs_to_cpu()
        result = ort_outs[0][0]

        # post-process
        result = np.clip(result, -1, 1)
        result = (result + 1) / 2
        result = result.transpose(1, 2, 0) * 255.0
        result = cv2.cvtColor(result, cv2.COLOR_RGB2BGR)
        scale_factor = int(result.shape[1] / input_size)
        return result.astype(np.uint8), scale_factor


    def Release(self):
        self.model_gfpgan = None
roop/processors/Enhance_GPEN.py
ADDED
@@ -0,0 +1,63 @@
from typing import Any, List, Callable
import cv2
import numpy as np
import onnxruntime
import roop.globals

from roop.typing import Face, Frame, FaceSet
from roop.utilities import resolve_relative_path


class Enhance_GPEN():
    plugin_options:dict = None

    model_gpen = None
    name = None
    devicename = None

    processorname = 'gpen'
    type = 'enhance'


    def Initialize(self, plugin_options:dict):
        if self.plugin_options is not None:
            if self.plugin_options["devicename"] != plugin_options["devicename"]:
                self.Release()

        self.plugin_options = plugin_options
        if self.model_gpen is None:
            model_path = resolve_relative_path('../models/GPEN-BFR-512.onnx')
            self.model_gpen = onnxruntime.InferenceSession(model_path, None, providers=roop.globals.execution_providers)
            # replace Mac mps with cpu for the moment
            self.devicename = self.plugin_options["devicename"].replace('mps', 'cpu')

        self.name = self.model_gpen.get_inputs()[0].name

    def Run(self, source_faceset: FaceSet, target_face: Face, temp_frame: Frame) -> Frame:
        # preprocess
        input_size = temp_frame.shape[1]
        temp_frame = cv2.resize(temp_frame, (512, 512), cv2.INTER_CUBIC)

        temp_frame = cv2.cvtColor(temp_frame, cv2.COLOR_BGR2RGB)
        temp_frame = temp_frame.astype('float32') / 255.0
        temp_frame = (temp_frame - 0.5) / 0.5
        temp_frame = np.expand_dims(temp_frame, axis=0).transpose(0, 3, 1, 2)

        io_binding = self.model_gpen.io_binding()
        io_binding.bind_cpu_input("input", temp_frame)
        io_binding.bind_output("output", self.devicename)
        self.model_gpen.run_with_iobinding(io_binding)
        ort_outs = io_binding.copy_outputs_to_cpu()
        result = ort_outs[0][0]

        # post-process
        result = np.clip(result, -1, 1)
        result = (result + 1) / 2
        result = result.transpose(1, 2, 0) * 255.0
        result = cv2.cvtColor(result, cv2.COLOR_RGB2BGR)
        scale_factor = int(result.shape[1] / input_size)
        return result.astype(np.uint8), scale_factor


    def Release(self):
        self.model_gpen = None
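GPEN binds its tensors by the literal names "input" and "output", while the GFPGAN model above exposes its output as "1288"; the other processors sidestep the issue by reading the names from the session. A small sketch (the helper name is hypothetical, not part of the upload) for checking what a given ONNX file actually expects:

import onnxruntime

def print_io_names(model_path: str) -> None:
    # hypothetical utility: list tensor names so hard-coded bindings can be verified
    session = onnxruntime.InferenceSession(model_path, providers=['CPUExecutionProvider'])
    print([i.name for i in session.get_inputs()])    # e.g. ['input'] for the GPEN model used here
    print([o.name for o in session.get_outputs()])   # e.g. ['output'], or ['1288'] for the GFPGAN model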
roop/processors/Enhance_RestoreFormerPPlus.py
ADDED
@@ -0,0 +1,64 @@
from typing import Any, List, Callable
import cv2
import numpy as np
import onnxruntime
import roop.globals

from roop.typing import Face, Frame, FaceSet
from roop.utilities import resolve_relative_path

class Enhance_RestoreFormerPPlus():
    plugin_options:dict = None
    model_restoreformerpplus = None
    devicename = None
    name = None

    processorname = 'restoreformer++'
    type = 'enhance'


    def Initialize(self, plugin_options:dict):
        if self.plugin_options is not None:
            if self.plugin_options["devicename"] != plugin_options["devicename"]:
                self.Release()

        self.plugin_options = plugin_options
        if self.model_restoreformerpplus is None:
            # replace Mac mps with cpu for the moment
            self.devicename = self.plugin_options["devicename"].replace('mps', 'cpu')
            model_path = resolve_relative_path('../models/restoreformer_plus_plus.onnx')
            self.model_restoreformerpplus = onnxruntime.InferenceSession(model_path, None, providers=roop.globals.execution_providers)
            self.model_inputs = self.model_restoreformerpplus.get_inputs()
            model_outputs = self.model_restoreformerpplus.get_outputs()
            self.io_binding = self.model_restoreformerpplus.io_binding()
            self.io_binding.bind_output(model_outputs[0].name, self.devicename)

    def Run(self, source_faceset: FaceSet, target_face: Face, temp_frame: Frame) -> Frame:
        # preprocess
        input_size = temp_frame.shape[1]
        temp_frame = cv2.resize(temp_frame, (512, 512), cv2.INTER_CUBIC)
        temp_frame = cv2.cvtColor(temp_frame, cv2.COLOR_BGR2RGB)
        temp_frame = temp_frame.astype('float32') / 255.0
        temp_frame = (temp_frame - 0.5) / 0.5
        temp_frame = np.expand_dims(temp_frame, axis=0).transpose(0, 3, 1, 2)

        self.io_binding.bind_cpu_input(self.model_inputs[0].name, temp_frame) # .astype(np.float32)
        self.model_restoreformerpplus.run_with_iobinding(self.io_binding)
        ort_outs = self.io_binding.copy_outputs_to_cpu()
        result = ort_outs[0][0]
        del ort_outs

        result = np.clip(result, -1, 1)
        result = (result + 1) / 2
        result = result.transpose(1, 2, 0) * 255.0
        result = cv2.cvtColor(result, cv2.COLOR_RGB2BGR)
        scale_factor = int(result.shape[1] / input_size)
        return result.astype(np.uint8), scale_factor


    def Release(self):
        del self.model_restoreformerpplus
        self.model_restoreformerpplus = None
        del self.io_binding
        self.io_binding = None
roop/processors/FaceSwapInsightFace.py
ADDED
@@ -0,0 +1,61 @@
import roop.globals
import cv2
import numpy as np
import onnx
import onnxruntime

from roop.typing import Face, Frame
from roop.utilities import resolve_relative_path



class FaceSwapInsightFace():
    plugin_options:dict = None
    model_swap_insightface = None

    processorname = 'faceswap'
    type = 'swap'


    def Initialize(self, plugin_options:dict):
        if self.plugin_options is not None:
            if self.plugin_options["devicename"] != plugin_options["devicename"]:
                self.Release()

        self.plugin_options = plugin_options
        if self.model_swap_insightface is None:
            model_path = resolve_relative_path('../models/inswapper_128.onnx')
            graph = onnx.load(model_path).graph
            self.emap = onnx.numpy_helper.to_array(graph.initializer[-1])
            self.devicename = self.plugin_options["devicename"].replace('mps', 'cpu')
            self.input_mean = 0.0
            self.input_std = 255.0
            #cuda_options = {"arena_extend_strategy": "kSameAsRequested", 'cudnn_conv_algo_search': 'DEFAULT'}
            sess_options = onnxruntime.SessionOptions()
            sess_options.enable_cpu_mem_arena = False
            self.model_swap_insightface = onnxruntime.InferenceSession(model_path, sess_options, providers=roop.globals.execution_providers)



    def Run(self, source_face: Face, target_face: Face, temp_frame: Frame) -> Frame:
        latent = source_face.normed_embedding.reshape((1,-1))
        latent = np.dot(latent, self.emap)
        latent /= np.linalg.norm(latent)
        io_binding = self.model_swap_insightface.io_binding()
        io_binding.bind_cpu_input("target", temp_frame)
        io_binding.bind_cpu_input("source", latent)
        io_binding.bind_output("output", self.devicename)
        self.model_swap_insightface.run_with_iobinding(io_binding)
        ort_outs = io_binding.copy_outputs_to_cpu()[0]
        return ort_outs[0]


    def Release(self):
        del self.model_swap_insightface
        self.model_swap_insightface = None
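The source latent fed to the inswapper model is not the raw face embedding: Run() projects it through emap, the matrix read from the last initializer of inswapper_128.onnx, and L2-normalizes it. A condensed sketch of just that step, assuming embedding holds source_face.normed_embedding and emap was loaded as in Initialize():

import numpy as np

def prepare_source_latent(embedding: np.ndarray, emap: np.ndarray) -> np.ndarray:
    latent = embedding.reshape((1, -1)).astype(np.float32)  # (1, 512) row vector
    latent = np.dot(latent, emap)                           # project into the swapper's latent space
    latent /= np.linalg.norm(latent)                        # unit length, as the "source" input expects
    return latent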
roop/processors/Frame_Colorizer.py
ADDED
@@ -0,0 +1,70 @@
import cv2
import numpy as np
import onnxruntime
import roop.globals

from roop.utilities import resolve_relative_path
from roop.typing import Frame

class Frame_Colorizer():
    plugin_options:dict = None
    model_colorizer = None
    devicename = None
    prev_type = None

    processorname = 'deoldify'
    type = 'frame_colorizer'


    def Initialize(self, plugin_options:dict):
        if self.plugin_options is not None:
            if self.plugin_options["devicename"] != plugin_options["devicename"]:
                self.Release()

        self.plugin_options = plugin_options
        if self.prev_type is not None and self.prev_type != self.plugin_options["subtype"]:
            self.Release()
        self.prev_type = self.plugin_options["subtype"]
        if self.model_colorizer is None:
            # replace Mac mps with cpu for the moment
            self.devicename = self.plugin_options["devicename"].replace('mps', 'cpu')
            if self.prev_type == "deoldify_artistic":
                model_path = resolve_relative_path('../models/Frame/deoldify_artistic.onnx')
            elif self.prev_type == "deoldify_stable":
                model_path = resolve_relative_path('../models/Frame/deoldify_stable.onnx')

            onnxruntime.set_default_logger_severity(3)
            self.model_colorizer = onnxruntime.InferenceSession(model_path, None, providers=roop.globals.execution_providers)
            self.model_inputs = self.model_colorizer.get_inputs()
            model_outputs = self.model_colorizer.get_outputs()
            self.io_binding = self.model_colorizer.io_binding()
            self.io_binding.bind_output(model_outputs[0].name, self.devicename)

    def Run(self, input_frame: Frame) -> Frame:
        temp_frame = cv2.cvtColor(input_frame, cv2.COLOR_BGR2GRAY)
        temp_frame = cv2.cvtColor(temp_frame, cv2.COLOR_GRAY2RGB)
        temp_frame = cv2.resize(temp_frame, (256, 256))
        temp_frame = temp_frame.transpose((2, 0, 1))
        temp_frame = np.expand_dims(temp_frame, axis=0).astype(np.float32)
        self.io_binding.bind_cpu_input(self.model_inputs[0].name, temp_frame)
        self.model_colorizer.run_with_iobinding(self.io_binding)
        ort_outs = self.io_binding.copy_outputs_to_cpu()
        result = ort_outs[0][0]
        del ort_outs
        colorized_frame = result.transpose(1, 2, 0)
        colorized_frame = cv2.resize(colorized_frame, (input_frame.shape[1], input_frame.shape[0]))
        temp_blue_channel, _, _ = cv2.split(input_frame)
        colorized_frame = cv2.cvtColor(colorized_frame, cv2.COLOR_BGR2RGB).astype(np.uint8)
        colorized_frame = cv2.cvtColor(colorized_frame, cv2.COLOR_BGR2LAB)
        _, color_green_channel, color_red_channel = cv2.split(colorized_frame)
        colorized_frame = cv2.merge((temp_blue_channel, color_green_channel, color_red_channel))
        colorized_frame = cv2.cvtColor(colorized_frame, cv2.COLOR_LAB2BGR)
        return colorized_frame.astype(np.uint8)


    def Release(self):
        del self.model_colorizer
        self.model_colorizer = None
        del self.io_binding
        self.io_binding = None
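Run() does not use the DeOldify output directly: it keeps the first channel of the original BGR frame in place of L and takes only the two chroma channels from the colorized result, so the original detail is preserved. A condensed sketch of that recombination, assuming both images are uint8 BGR of the same size:

import cv2

def transfer_chroma(original_bgr, colorized_bgr):
    detail_channel, _, _ = cv2.split(original_bgr)          # original blue channel stands in for L
    _, chroma_a, chroma_b = cv2.split(cv2.cvtColor(colorized_bgr, cv2.COLOR_BGR2LAB))
    merged = cv2.merge((detail_channel, chroma_a, chroma_b))
    return cv2.cvtColor(merged, cv2.COLOR_LAB2BGR)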
roop/processors/Frame_Filter.py
ADDED
@@ -0,0 +1,105 @@
import cv2
import numpy as np

from roop.typing import Frame

class Frame_Filter():
    processorname = 'generic_filter'
    type = 'frame_processor'

    plugin_options:dict = None

    c64_palette = np.array([
        [0, 0, 0],
        [255, 255, 255],
        [0x81, 0x33, 0x38],
        [0x75, 0xce, 0xc8],
        [0x8e, 0x3c, 0x97],
        [0x56, 0xac, 0x4d],
        [0x2e, 0x2c, 0x9b],
        [0xed, 0xf1, 0x71],
        [0x8e, 0x50, 0x29],
        [0x55, 0x38, 0x00],
        [0xc4, 0x6c, 0x71],
        [0x4a, 0x4a, 0x4a],
        [0x7b, 0x7b, 0x7b],
        [0xa9, 0xff, 0x9f],
        [0x70, 0x6d, 0xeb],
        [0xb2, 0xb2, 0xb2]
    ])


    def RenderC64Screen(self, image):
        # Simply round the color values to the nearest color in the palette
        image = cv2.resize(image,(320,200))
        palette = self.c64_palette / 255.0 # Normalize palette
        img_normalized = image / 255.0 # Normalize image

        # Calculate the index in the palette that is closest to each pixel in the image
        indices = np.sqrt(((img_normalized[:, :, None, :] - palette[None, None, :, :]) ** 2).sum(axis=3)).argmin(axis=2)
        # Map the image to the palette colors
        mapped_image = palette[indices]
        return (mapped_image * 255).astype(np.uint8) # Denormalize and return the image


    def RenderDetailEnhance(self, image):
        return cv2.detailEnhance(image)

    def RenderStylize(self, image):
        return cv2.stylization(image)

    def RenderPencilSketch(self, image):
        imgray, imout = cv2.pencilSketch(image, sigma_s=60, sigma_r=0.07, shade_factor=0.05)
        return imout

    def RenderCartoon(self, image):
        numDownSamples = 2       # number of downscaling steps
        numBilateralFilters = 7  # number of bilateral filtering steps

        img_color = image
        for _ in range(numDownSamples):
            img_color = cv2.pyrDown(img_color)
        for _ in range(numBilateralFilters):
            img_color = cv2.bilateralFilter(img_color, 9, 9, 7)
        for _ in range(numDownSamples):
            img_color = cv2.pyrUp(img_color)
        img_gray = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)
        img_blur = cv2.medianBlur(img_gray, 7)
        img_edge = cv2.adaptiveThreshold(img_blur, 255,
                                         cv2.ADAPTIVE_THRESH_MEAN_C, cv2.THRESH_BINARY, 9, 2)
        img_edge = cv2.cvtColor(img_edge, cv2.COLOR_GRAY2RGB)
        if img_color.shape != image.shape:
            img_color = cv2.resize(img_color, (image.shape[1], image.shape[0]), interpolation=cv2.INTER_LINEAR)
        if img_color.shape != img_edge.shape:
            img_edge = cv2.resize(img_edge, (img_color.shape[1], img_color.shape[0]), interpolation=cv2.INTER_LINEAR)
        return cv2.bitwise_and(img_color, img_edge)


    def Initialize(self, plugin_options:dict):
        if self.plugin_options is not None:
            if self.plugin_options["devicename"] != plugin_options["devicename"]:
                self.Release()
        self.plugin_options = plugin_options

    def Run(self, temp_frame: Frame) -> Frame:
        subtype = self.plugin_options["subtype"]
        if subtype == "stylize":
            return self.RenderStylize(temp_frame).astype(np.uint8)
        if subtype == "detailenhance":
            return self.RenderDetailEnhance(temp_frame).astype(np.uint8)
        if subtype == "pencil":
            return self.RenderPencilSketch(temp_frame).astype(np.uint8)
        if subtype == "cartoon":
            return self.RenderCartoon(temp_frame).astype(np.uint8)
        if subtype == "C64":
            return self.RenderC64Screen(temp_frame).astype(np.uint8)


    def Release(self):
        pass

    def getProcessedResolution(self, width, height):
        if self.plugin_options["subtype"] == "C64":
            return (320,200)
        return None
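RenderC64Screen() quantizes colors by brute force: for every pixel it computes the Euclidean distance to each palette entry and keeps the closest one. A tiny worked example with a three-color palette (values chosen only for illustration):

import numpy as np

palette = np.array([[0, 0, 0], [255, 255, 255], [0x81, 0x33, 0x38]], dtype=np.float32) / 255.0
pixels = np.array([[[10, 12, 8], [200, 190, 210]]], dtype=np.float32) / 255.0   # shape (1, 2, 3)
dist = np.sqrt(((pixels[:, :, None, :] - palette[None, None, :, :]) ** 2).sum(axis=3))
indices = dist.argmin(axis=2)                        # [[0, 1]]: near-black, near-white
quantized = (palette[indices] * 255).astype(np.uint8)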
roop/processors/Frame_Masking.py
ADDED
@@ -0,0 +1,71 @@
import cv2
import numpy as np
import onnxruntime
import roop.globals

from roop.utilities import resolve_relative_path
from roop.typing import Frame

class Frame_Masking():
    plugin_options:dict = None
    model_masking = None
    devicename = None
    name = None

    processorname = 'removebg'
    type = 'frame_masking'


    def Initialize(self, plugin_options:dict):
        if self.plugin_options is not None:
            if self.plugin_options["devicename"] != plugin_options["devicename"]:
                self.Release()

        self.plugin_options = plugin_options
        if self.model_masking is None:
            # replace Mac mps with cpu for the moment
            self.devicename = self.plugin_options["devicename"]
            self.devicename = self.devicename.replace('mps', 'cpu')
            model_path = resolve_relative_path('../models/Frame/isnet-general-use.onnx')
            self.model_masking = onnxruntime.InferenceSession(model_path, None, providers=roop.globals.execution_providers)
            self.model_inputs = self.model_masking.get_inputs()
            model_outputs = self.model_masking.get_outputs()
            self.io_binding = self.model_masking.io_binding()
            self.io_binding.bind_output(model_outputs[0].name, self.devicename)

    def Run(self, temp_frame: Frame) -> Frame:
        # Pre process: Resize, BGR->RGB, float32 cast
        input_image = cv2.resize(temp_frame, (1024, 1024))
        input_image = cv2.cvtColor(input_image, cv2.COLOR_BGR2RGB)
        mean = [0.5, 0.5, 0.5]
        std = [1.0, 1.0, 1.0]
        input_image = (input_image / 255.0 - mean) / std
        input_image = input_image.transpose(2, 0, 1)
        input_image = np.expand_dims(input_image, axis=0)
        input_image = input_image.astype('float32')

        self.io_binding.bind_cpu_input(self.model_inputs[0].name, input_image)
        self.model_masking.run_with_iobinding(self.io_binding)
        ort_outs = self.io_binding.copy_outputs_to_cpu()
        result = ort_outs[0][0]
        del ort_outs
        # Post process: squeeze, sigmoid, normalize, uint8 cast
        mask = np.squeeze(result[0])
        min_value = np.min(mask)
        max_value = np.max(mask)
        mask = (mask - min_value) / (max_value - min_value)
        #mask = np.where(mask < score_th, 0, 1)
        #mask *= 255
        mask = cv2.resize(mask, (temp_frame.shape[1], temp_frame.shape[0]), interpolation=cv2.INTER_LINEAR)
        mask = np.reshape(mask, [mask.shape[0],mask.shape[1],1])
        result = mask * temp_frame.astype(np.float32)
        return result.astype(np.uint8)



    def Release(self):
        del self.model_masking
        self.model_masking = None
        del self.io_binding
        self.io_binding = None
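The ISNet output is a single-channel saliency map; Run() min-max scales it to [0, 1], resizes it back to the frame, and multiplies it in, so background pixels fade towards black. A reduced sketch of that final step (the epsilon guard for an all-equal map is an addition, not in the original; saliency is assumed to already match the frame size):

import numpy as np

def apply_saliency_mask(frame_bgr: np.ndarray, saliency: np.ndarray) -> np.ndarray:
    mask = (saliency - saliency.min()) / (saliency.max() - saliency.min() + 1e-8)  # scale to [0, 1]
    return (mask[..., None] * frame_bgr.astype(np.float32)).astype(np.uint8)       # per-pixel alpha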
roop/processors/Frame_Upscale.py
ADDED
@@ -0,0 +1,129 @@
import cv2
import numpy as np
import onnxruntime
import roop.globals

from roop.utilities import resolve_relative_path, conditional_thread_semaphore
from roop.typing import Frame


class Frame_Upscale():
    plugin_options:dict = None
    model_upscale = None
    devicename = None
    prev_type = None

    processorname = 'upscale'
    type = 'frame_enhancer'


    def Initialize(self, plugin_options:dict):
        if self.plugin_options is not None:
            if self.plugin_options["devicename"] != plugin_options["devicename"]:
                self.Release()

        self.plugin_options = plugin_options
        if self.prev_type is not None and self.prev_type != self.plugin_options["subtype"]:
            self.Release()
        self.prev_type = self.plugin_options["subtype"]
        if self.model_upscale is None:
            # replace Mac mps with cpu for the moment
            self.devicename = self.plugin_options["devicename"].replace('mps', 'cpu')
            if self.prev_type == "esrganx4":
                model_path = resolve_relative_path('../models/Frame/real_esrgan_x4.onnx')
                self.scale = 4
            elif self.prev_type == "esrganx2":
                model_path = resolve_relative_path('../models/Frame/real_esrgan_x2.onnx')
                self.scale = 2
            elif self.prev_type == "lsdirx4":
                model_path = resolve_relative_path('../models/Frame/lsdir_x4.onnx')
                self.scale = 4
            onnxruntime.set_default_logger_severity(3)
            self.model_upscale = onnxruntime.InferenceSession(model_path, None, providers=roop.globals.execution_providers)
            self.model_inputs = self.model_upscale.get_inputs()
            model_outputs = self.model_upscale.get_outputs()
            self.io_binding = self.model_upscale.io_binding()
            self.io_binding.bind_output(model_outputs[0].name, self.devicename)

    def getProcessedResolution(self, width, height):
        return (width * self.scale, height * self.scale)

    # borrowed from facefusion -> https://github.com/facefusion/facefusion
    def prepare_tile_frame(self, tile_frame : Frame) -> Frame:
        tile_frame = np.expand_dims(tile_frame[:, :, ::-1], axis = 0)
        tile_frame = tile_frame.transpose(0, 3, 1, 2)
        tile_frame = tile_frame.astype(np.float32) / 255
        return tile_frame


    def normalize_tile_frame(self, tile_frame : Frame) -> Frame:
        tile_frame = tile_frame.transpose(0, 2, 3, 1).squeeze(0) * 255
        tile_frame = tile_frame.clip(0, 255).astype(np.uint8)[:, :, ::-1]
        return tile_frame

    def create_tile_frames(self, input_frame : Frame, size):
        input_frame = np.pad(input_frame, ((size[1], size[1]), (size[1], size[1]), (0, 0)))
        tile_width = size[0] - 2 * size[2]
        pad_size_bottom = size[2] + tile_width - input_frame.shape[0] % tile_width
        pad_size_right = size[2] + tile_width - input_frame.shape[1] % tile_width
        pad_vision_frame = np.pad(input_frame, ((size[2], pad_size_bottom), (size[2], pad_size_right), (0, 0)))
        pad_height, pad_width = pad_vision_frame.shape[:2]
        row_range = range(size[2], pad_height - size[2], tile_width)
        col_range = range(size[2], pad_width - size[2], tile_width)
        tile_frames = []

        for row_frame in row_range:
            top = row_frame - size[2]
            bottom = row_frame + size[2] + tile_width
            for column_vision_frame in col_range:
                left = column_vision_frame - size[2]
                right = column_vision_frame + size[2] + tile_width
                tile_frames.append(pad_vision_frame[top:bottom, left:right, :])
        return tile_frames, pad_width, pad_height


    def merge_tile_frames(self, tile_frames, temp_width : int, temp_height : int, pad_width : int, pad_height : int, size) -> Frame:
        merge_frame = np.zeros((pad_height, pad_width, 3)).astype(np.uint8)
        tile_width = tile_frames[0].shape[1] - 2 * size[2]
        tiles_per_row = min(pad_width // tile_width, len(tile_frames))

        for index, tile_frame in enumerate(tile_frames):
            tile_frame = tile_frame[size[2]:-size[2], size[2]:-size[2]]
            row_index = index // tiles_per_row
            col_index = index % tiles_per_row
            top = row_index * tile_frame.shape[0]
            bottom = top + tile_frame.shape[0]
            left = col_index * tile_frame.shape[1]
            right = left + tile_frame.shape[1]
            merge_frame[top:bottom, left:right, :] = tile_frame
        merge_frame = merge_frame[size[1] : size[1] + temp_height, size[1]: size[1] + temp_width, :]
        return merge_frame


    def Run(self, temp_frame: Frame) -> Frame:
        size = (128, 8, 2)
        temp_height, temp_width = temp_frame.shape[:2]
        upscale_tile_frames, pad_width, pad_height = self.create_tile_frames(temp_frame, size)

        for index, tile_frame in enumerate(upscale_tile_frames):
            tile_frame = self.prepare_tile_frame(tile_frame)
            with conditional_thread_semaphore():
                self.io_binding.bind_cpu_input(self.model_inputs[0].name, tile_frame)
                self.model_upscale.run_with_iobinding(self.io_binding)
                ort_outs = self.io_binding.copy_outputs_to_cpu()
                result = ort_outs[0]
            upscale_tile_frames[index] = self.normalize_tile_frame(result)
        final_frame = self.merge_tile_frames(upscale_tile_frames, temp_width * self.scale
                                             , temp_height * self.scale
                                             , pad_width * self.scale, pad_height * self.scale
                                             , (size[0] * self.scale, size[1] * self.scale, size[2] * self.scale))
        return final_frame.astype(np.uint8)



    def Release(self):
        del self.model_upscale
        self.model_upscale = None
        del self.io_binding
        self.io_binding = None
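With size = (128, 8, 2) each tile is 128 px square with a 2 px overlap border, so every tile contributes 124 px of unique content and the frame is run through the ONNX session one tile at a time. A back-of-the-envelope count for a 1280x720 frame (the resolution and the derived tile counts are only an illustration of the padding arithmetic in create_tile_frames):

tile, border, overlap = 128, 8, 2
step = tile - 2 * overlap                    # 124 px of unique content per tile
h, w = 720 + 2 * border, 1280 + 2 * border   # frame size after the first padding step
rows = h // step + 1                         # 6 rows of tiles
cols = w // step + 1                         # 11 columns of tiles
print(rows * cols)                           # 66 model invocations for this single frame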
roop/processors/Mask_Clip2Seg.py
ADDED
@@ -0,0 +1,94 @@
import cv2
import numpy as np
import torch
import threading
from torchvision import transforms
from clip.clipseg import CLIPDensePredT
import numpy as np

from roop.typing import Frame

THREAD_LOCK_CLIP = threading.Lock()


class Mask_Clip2Seg():
    plugin_options:dict = None
    model_clip = None

    processorname = 'clip2seg'
    type = 'mask'


    def Initialize(self, plugin_options:dict):
        if self.plugin_options is not None:
            if self.plugin_options["devicename"] != plugin_options["devicename"]:
                self.Release()

        self.plugin_options = plugin_options
        if self.model_clip is None:
            self.model_clip = CLIPDensePredT(version='ViT-B/16', reduce_dim=64, complex_trans_conv=True)
            self.model_clip.eval();
            self.model_clip.load_state_dict(torch.load('models/CLIP/rd64-uni-refined.pth', map_location=torch.device('cpu')), strict=False)

        device = torch.device(self.plugin_options["devicename"])
        self.model_clip.to(device)


    def Run(self, img1, keywords:str) -> Frame:
        if keywords is None or len(keywords) < 1 or img1 is None:
            return img1

        source_image_small = cv2.resize(img1, (256,256))

        img_mask = np.full((source_image_small.shape[0],source_image_small.shape[1]), 0, dtype=np.float32)
        mask_border = 1
        l = 0
        t = 0
        r = 1
        b = 1

        mask_blur = 5
        clip_blur = 5

        img_mask = cv2.rectangle(img_mask, (mask_border+int(l), mask_border+int(t)),
                                 (256 - mask_border-int(r), 256-mask_border-int(b)), (255, 255, 255), -1)
        img_mask = cv2.GaussianBlur(img_mask, (mask_blur*2+1,mask_blur*2+1), 0)
        img_mask /= 255


        input_image = source_image_small

        transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
            transforms.Resize((256, 256)),
        ])
        img = transform(input_image).unsqueeze(0)

        thresh = 0.5
        prompts = keywords.split(',')
        with THREAD_LOCK_CLIP:
            with torch.no_grad():
                preds = self.model_clip(img.repeat(len(prompts),1,1,1), prompts)[0]
        clip_mask = torch.sigmoid(preds[0][0])
        for i in range(len(prompts)-1):
            clip_mask += torch.sigmoid(preds[i+1][0])

        clip_mask = clip_mask.data.cpu().numpy()
        np.clip(clip_mask, 0, 1)

        clip_mask[clip_mask>thresh] = 1.0
        clip_mask[clip_mask<=thresh] = 0.0
        kernel = np.ones((5, 5), np.float32)
        clip_mask = cv2.dilate(clip_mask, kernel, iterations=1)
        clip_mask = cv2.GaussianBlur(clip_mask, (clip_blur*2+1,clip_blur*2+1), 0)

        img_mask *= clip_mask
        img_mask[img_mask<0.0] = 0.0
        return img_mask



    def Release(self):
        self.model_clip = None
roop/processors/Mask_XSeg.py
ADDED
@@ -0,0 +1,58 @@
import numpy as np
import cv2
import onnxruntime
import roop.globals

from roop.typing import Frame
from roop.utilities import resolve_relative_path, conditional_thread_semaphore



class Mask_XSeg():
    plugin_options:dict = None

    model_xseg = None

    processorname = 'mask_xseg'
    type = 'mask'


    def Initialize(self, plugin_options:dict):
        if self.plugin_options is not None:
            if self.plugin_options["devicename"] != plugin_options["devicename"]:
                self.Release()

        self.plugin_options = plugin_options
        if self.model_xseg is None:
            model_path = resolve_relative_path('../models/xseg.onnx')
            onnxruntime.set_default_logger_severity(3)
            self.model_xseg = onnxruntime.InferenceSession(model_path, None, providers=roop.globals.execution_providers)
            self.model_inputs = self.model_xseg.get_inputs()
            self.model_outputs = self.model_xseg.get_outputs()

        # replace Mac mps with cpu for the moment
        self.devicename = self.plugin_options["devicename"].replace('mps', 'cpu')


    def Run(self, img1, keywords:str) -> Frame:
        temp_frame = cv2.resize(img1, (256, 256), cv2.INTER_CUBIC)
        temp_frame = temp_frame.astype('float32') / 255.0
        temp_frame = temp_frame[None, ...]
        io_binding = self.model_xseg.io_binding()
        io_binding.bind_cpu_input(self.model_inputs[0].name, temp_frame)
        io_binding.bind_output(self.model_outputs[0].name, self.devicename)
        self.model_xseg.run_with_iobinding(io_binding)
        ort_outs = io_binding.copy_outputs_to_cpu()
        result = ort_outs[0][0]
        result = np.clip(result, 0, 1.0)
        result[result < 0.1] = 0
        # invert values to mask areas to keep
        result = 1.0 - result
        return result


    def Release(self):
        del self.model_xseg
        self.model_xseg = None
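Run() resizes the input to the 256x256 resolution the XSeg model expects, clips the raw output to [0, 1], zeroes responses below 0.1 and then inverts the map so the returned values flag the areas to keep, per the in-code comment. A toy sketch of just that post-processing (the 2x2 array is made up for illustration):

import numpy as np

raw = np.array([[0.05, 0.40], [0.90, 0.00]], dtype=np.float32)   # stand-in for the model output
mask = np.clip(raw, 0, 1.0)
mask[mask < 0.1] = 0              # drop weak responses
mask = 1.0 - mask                 # invert, mirroring the "areas to keep" comment in Run()
# mask is now approximately [[1.0, 0.6], [0.1, 1.0]]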