initial commit
- .gitignore +11 -0
- conr.py +292 -0
- data_loader.py +273 -0
- infer.sh +14 -0
- model/__init__.py +1 -0
- model/backbone.py +285 -0
- model/decoder_small.py +43 -0
- model/shader.py +290 -0
- model/warplayer.py +56 -0
- streamlit.py +52 -0
- train.py +229 -0
.gitignore
ADDED
@@ -0,0 +1,11 @@
results/
test_data/
test_data_pre/
weights/
x264/
*.mp3
*.mp4
*.txt
*.png
complex_infer.sh
__pycache__/
conr.py
ADDED
@@ -0,0 +1,292 @@
import os
import torch

from model.backbone import ResEncUnet
from model.shader import CINN
from model.decoder_small import RGBADecoderNet

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


def UDPClip(x):
    return torch.clamp(x, min=0, max=1)  # NCHW


class CoNR():
    def __init__(self, args):
        self.args = args

        self.udpparsernet = ResEncUnet(
            backbone_name='resnet50_danbo',
            classes=4,
            pretrained=(args.local_rank == 0),
            parametric_upsampling=True,
            decoder_filters=(512, 384, 256, 128, 32),
            map_location=device
        )
        self.target_pose_encoder = ResEncUnet(
            backbone_name='resnet18_danbo-4',
            classes=1,
            pretrained=(args.local_rank == 0),
            parametric_upsampling=True,
            decoder_filters=(512, 384, 256, 128, 32),
            map_location=device
        )
        self.DIM_SHADER_REFERENCE = 4
        self.shader = CINN(self.DIM_SHADER_REFERENCE)
        self.rgbadecodernet = RGBADecoderNet()
        self.device()
        self.parser_ckpt = None

    def dist(self):
        args = self.args
        if args.distributed:
            self.udpparsernet = torch.nn.parallel.DistributedDataParallel(
                self.udpparsernet,
                device_ids=[args.local_rank],
                output_device=args.local_rank,
                broadcast_buffers=False,
                find_unused_parameters=True
            )
            self.target_pose_encoder = torch.nn.parallel.DistributedDataParallel(
                self.target_pose_encoder,
                device_ids=[args.local_rank],
                output_device=args.local_rank,
                broadcast_buffers=False,
                find_unused_parameters=True
            )
            self.shader = torch.nn.parallel.DistributedDataParallel(
                self.shader,
                device_ids=[args.local_rank],
                output_device=args.local_rank,
                broadcast_buffers=True
            )
            self.rgbadecodernet = torch.nn.parallel.DistributedDataParallel(
                self.rgbadecodernet,
                device_ids=[args.local_rank],
                output_device=args.local_rank,
                broadcast_buffers=True
            )

    def load_model(self, path):
        self.udpparsernet.load_state_dict(
            torch.load('{}/udpparsernet.pth'.format(path), map_location=device))
        self.target_pose_encoder.load_state_dict(
            torch.load('{}/target_pose_encoder.pth'.format(path), map_location=device))
        self.shader.load_state_dict(
            torch.load('{}/shader.pth'.format(path), map_location=device))
        self.rgbadecodernet.load_state_dict(
            torch.load('{}/rgbadecodernet.pth'.format(path), map_location=device))

    def save_model(self, ite_num):
        self._save_pth(self.udpparsernet,
                       model_name="udpparsernet", ite_num=ite_num)
        self._save_pth(self.target_pose_encoder,
                       model_name="target_pose_encoder", ite_num=ite_num)
        self._save_pth(self.shader,
                       model_name="shader", ite_num=ite_num)
        self._save_pth(self.rgbadecodernet,
                       model_name="rgbadecodernet", ite_num=ite_num)

    def _save_pth(self, net, model_name, ite_num):
        args = self.args
        to_save = None
        if args.distributed:
            if args.local_rank == 0:
                to_save = net.module.state_dict()
        else:
            to_save = net.state_dict()
        if to_save:
            model_dir = os.path.join(
                os.getcwd(), 'saved_models',
                args.model_name + os.sep + "checkpoints" + os.sep + "itr_%d" % (ite_num) + os.sep)
            os.makedirs(model_dir, exist_ok=True)
            torch.save(to_save, model_dir + model_name + ".pth")

    def train(self):
        self.udpparsernet.train()
        self.target_pose_encoder.train()
        self.shader.train()
        self.rgbadecodernet.train()

    def eval(self):
        self.udpparsernet.eval()
        self.target_pose_encoder.eval()
        self.shader.eval()
        self.rgbadecodernet.eval()

    def device(self):
        self.udpparsernet.to(device)
        self.target_pose_encoder.to(device)
        self.shader.to(device)
        self.rgbadecodernet.to(device)

    def data_norm_image(self, data):
        with torch.cuda.amp.autocast(enabled=False):
            for name in ["character_labels", "pose_label"]:
                if name in data:
                    data[name] = data[name].to(
                        device, non_blocking=True).float()
            for name in ["pose_images", "pose_mask", "character_images", "character_masks"]:
                if name in data:
                    data[name] = data[name].to(
                        device, non_blocking=True).float() / 255.0
            if "pose_images" in data:
                data["num_pose_images"] = data["pose_images"].shape[1]
                data["num_samples"] = data["pose_images"].shape[0]
            if "character_images" in data:
                data["num_character_images"] = data["character_images"].shape[1]
                data["num_samples"] = data["character_images"].shape[0]
            if "pose_images" in data and "character_images" in data:
                assert (data["pose_images"].shape[0] ==
                        data["character_images"].shape[0])
        return data

    def reset_charactersheet(self):
        self.parser_ckpt = None

    def model_step(self, data, training=False):
        self.eval()
        with torch.cuda.amp.autocast(enabled=False):
            pred = {}
            if self.parser_ckpt:
                pred["parser"] = self.parser_ckpt
            else:
                pred = self.character_parser_forward(data, pred)
                self.parser_ckpt = pred["parser"]
            pred = self.pose_parser_sc_forward(data, pred)
            pred = self.shader_pose_encoder_forward(data, pred)
            pred = self.shader_forward(data, pred)
        return pred

    def shader_forward(self, data, pred={}):
        assert ("num_character_images" in data), "ERROR: No Character Sheet input."

        character_images_rgb_nmchw, num_character_images = data[
            "character_images"], data["num_character_images"]
        # build x_reference_rgb_a_sudp in the draw call
        shader_character_a_nmchw = data["character_masks"]
        assert torch.any(torch.mean(shader_character_a_nmchw, (0, 2, 3, 4)) >= 0.95) == False, "ERROR: \
No transparent area found in the image, PLEASE separate the foreground of input character sheets.\
The website waifucutout.com is recommended to automatically cut out the foreground."

        if shader_character_a_nmchw is None:
            shader_character_a_nmchw = pred["parser"]["pred"][:, :, 3:4, :, :]
        x_reference_rgb_a = torch.cat([shader_character_a_nmchw[:, :, :, :, :] * character_images_rgb_nmchw[:, :, :, :, :],
                                       shader_character_a_nmchw[:, :, :, :, :],
                                       ], 2)
        assert (x_reference_rgb_a.shape[2] == self.DIM_SHADER_REFERENCE)
        # build x_reference_features in the draw call
        x_reference_features = pred["parser"]["features"]
        # run cinn shader
        retdic = self.shader(
            pred["shader"]["target_pose_features"], x_reference_rgb_a, x_reference_features)
        pred["shader"].update(retdic)

        # decode rgba
        if True:
            dec_out = self.rgbadecodernet(
                retdic["y_last_remote_features"])
            y_weighted_x_reference_RGB = dec_out[:, 0:3, :, :]
            y_weighted_mask_A = dec_out[:, 3:4, :, :]
            y_weighted_warp_decoded_rgba = torch.cat(
                (y_weighted_x_reference_RGB*y_weighted_mask_A, y_weighted_mask_A), dim=1
            )
            assert (y_weighted_warp_decoded_rgba.shape[1] == 4)
            assert (
                y_weighted_warp_decoded_rgba.shape[-1] == character_images_rgb_nmchw.shape[-1])
            # apply decoded mask to decoded rgb, finishing the draw call
            pred["shader"]["y_weighted_warp_decoded_rgba"] = y_weighted_warp_decoded_rgba
        return pred

    def character_parser_forward(self, data, pred={}):
        if not ("num_character_images" in data and "character_images" in data):
            return pred
        pred["parser"] = {"pred": None}  # create output

        inputs_rgb_nmchw, num_samples, num_character_images = data[
            "character_images"], data["num_samples"], data["num_character_images"]
        inputs_rgb_fchw = inputs_rgb_nmchw.view(
            (num_samples * num_character_images, inputs_rgb_nmchw.shape[2], inputs_rgb_nmchw.shape[3], inputs_rgb_nmchw.shape[4]))

        encoder_out, features = self.udpparsernet(
            (inputs_rgb_fchw-0.6)/0.2970)

        pred["parser"]["features"] = [features_out.view(
            (num_samples, num_character_images, features_out.shape[1], features_out.shape[2], features_out.shape[3])) for features_out in features]

        if (encoder_out is not None):
            pred["parser"]["pred"] = UDPClip(encoder_out.view(
                (num_samples, num_character_images, encoder_out.shape[1], encoder_out.shape[2], encoder_out.shape[3])))

        return pred

    def pose_parser_sc_forward(self, data, pred={}):
        if not ("num_pose_images" in data and "pose_images" in data):
            return pred
        inputs_aug_rgb_nmchw, num_samples, num_pose_images = data[
            "pose_images"], data["num_samples"], data["num_pose_images"]
        inputs_aug_rgb_fchw = inputs_aug_rgb_nmchw.view(
            (num_samples * num_pose_images, inputs_aug_rgb_nmchw.shape[2], inputs_aug_rgb_nmchw.shape[3], inputs_aug_rgb_nmchw.shape[4]))

        encoder_out, _ = self.udpparsernet(
            (inputs_aug_rgb_fchw-0.6)/0.2970)

        encoder_out = encoder_out.view(
            (num_samples, num_pose_images, encoder_out.shape[1], encoder_out.shape[2], encoder_out.shape[3]))

        # apply sigmoid after eval loss
        pred["pose_parser"] = {"pred": UDPClip(encoder_out)}

        return pred

    def shader_pose_encoder_forward(self, data, pred={}):
        pred["shader"] = {}  # create output
        if "pose_images" in data:
            pose_images_rgb_nmchw = data["pose_images"]
            target_gt_rgb = pose_images_rgb_nmchw[:, 0, :, :, :]
            pred["shader"]["target_gt_rgb"] = target_gt_rgb

        shader_target_a = None
        if "pose_mask" in data:
            pred["shader"]["target_gt_a"] = data["pose_mask"]
            shader_target_a = data["pose_mask"]

        shader_target_sudp = None
        if "pose_label" in data:
            shader_target_sudp = data["pose_label"][:, :3, :, :]

        if self.args.test_pose_use_parser_udp:
            shader_target_sudp = None
        if shader_target_sudp is None:
            shader_target_sudp = pred["pose_parser"]["pred"][:, 0:3, :, :]

        if shader_target_a is None:
            shader_target_a = pred["pose_parser"]["pred"][:, 3:4, :, :]

        # build x_target_sudp_a in the draw call
        x_target_sudp_a = torch.cat((
            shader_target_sudp*shader_target_a,
            shader_target_a
        ), 1)
        pred["shader"].update({
            "x_target_sudp_a": x_target_sudp_a
        })
        _, features = self.target_pose_encoder(
            (x_target_sudp_a-0.6)/0.2970, ret_parser_out=False)

        pred["shader"]["target_pose_features"] = features
        return pred
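A rough usage sketch for the wrapper above (not part of the commit): an args object needs only the attributes CoNR reads, the four checkpoints written by save_model are restored with load_model, and a normalized batch goes through model_step. The argument values, paths, and the assumption that pose and character images share one resolution are all illustrative.

import torch
from torch.utils.data import DataLoader
from types import SimpleNamespace

from conr import CoNR
from data_loader import FileDataset

# Only the attributes CoNR actually reads are set; values and paths are illustrative.
args = SimpleNamespace(local_rank=0, distributed=False,
                       test_pose_use_parser_udp=False, model_name="conr_demo")

model = CoNR(args)
model.load_model("./weights")   # expects udpparsernet.pth, target_pose_encoder.pth, shader.pth, rgbadecodernet.pth
model.eval()

# One sample: target pose UDP file first, then any number of character-sheet views.
files = [["./test_data_pre/pose_0001.npz",
          "./character_sheet/front.png", "./character_sheet/back.png"]]
loader = DataLoader(FileDataset(files, shader_pose_use_gt_udp_test=True), batch_size=1)

with torch.no_grad():
    for batch in loader:
        batch = model.data_norm_image(batch)
        pred = model.model_step(batch)
        rgba = pred["shader"]["y_weighted_warp_decoded_rgba"]   # (N, 4, H, W) in [0, 1]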
data_loader.py
ADDED
@@ -0,0 +1,273 @@
import cv2
import numpy as np
import torch
from torch.utils.data import Dataset
import os
cv2.setNumThreads(1)
os.environ["OPENCV_IO_ENABLE_OPENEXR"] = "1"


class RandomResizedCropWithAutoCenteringAndZeroPadding(object):
    def __init__(self, output_size, scale=(0.08, 1.0), ratio=(3. / 4., 4. / 3.), center_jitter=(0.1, 0.1), size_from_alpha_mask=True):
        assert isinstance(output_size, (int, tuple))
        if isinstance(output_size, int):
            self.output_size = (output_size, output_size)
        else:
            assert len(output_size) == 2
            self.output_size = output_size
        assert isinstance(scale, tuple)
        assert isinstance(ratio, tuple)
        if (scale[0] > scale[1]) or (ratio[0] > ratio[1]):
            raise ValueError("Scale and ratio should be of kind (min, max)")
        self.size_from_alpha_mask = size_from_alpha_mask
        self.scale = scale
        self.ratio = ratio
        assert isinstance(center_jitter, tuple)
        self.center_jitter = center_jitter

    def __call__(self, sample):
        imidx, image = sample['imidx'], sample["image_np"]
        if "labels" in sample:
            label = sample["labels"]
        else:
            label = None

        im_h, im_w = image.shape[:2]
        if self.size_from_alpha_mask and image.shape[2] == 4:
            # compute bbox from alpha mask
            bbox_left, bbox_top, bbox_w, bbox_h = cv2.boundingRect(
                (image[:, :, 3] > 0).astype(np.uint8))
        else:
            bbox_left, bbox_top = 0, 0
            bbox_h, bbox_w = image.shape[:2]
        if bbox_h <= 1 and bbox_w <= 1:
            sample["bad"] = 0
        else:
            # detect too small image here
            alpha_varea = np.sum((image[:, :, 3] > 0).astype(np.uint8))
            image_area = image.shape[0]*image.shape[1]
            if alpha_varea/image_area < 0.001:
                sample["bad"] = alpha_varea
        # detect bad image
        if "bad" in sample:
            # baddata_dir = os.path.join(os.getcwd(), 'test_data', "baddata" + os.sep)
            # save_output(str(imidx)+".png",image,label,baddata_dir)
            bbox_h, bbox_w = image.shape[:2]
            sample["image_np"] = np.zeros(
                [self.output_size[0], self.output_size[1], image.shape[2]], dtype=image.dtype)
            if label is not None:
                sample["labels"] = np.zeros(
                    [self.output_size[0], self.output_size[1], 4], dtype=label.dtype)
            return sample

        # compute default area by making sure output_size contains bbox_w * bbox_h

        jitter_h = np.random.uniform(-bbox_h *
                                     self.center_jitter[0], bbox_h*self.center_jitter[0])
        jitter_w = np.random.uniform(-bbox_w *
                                     self.center_jitter[1], bbox_w*self.center_jitter[1])

        # h/w
        target_aspect_ratio = np.exp(
            np.log(self.output_size[0]/self.output_size[1]) +
            np.random.uniform(np.log(self.ratio[0]), np.log(self.ratio[1]))
        )

        source_aspect_ratio = bbox_h/bbox_w

        if target_aspect_ratio < source_aspect_ratio:
            # same w, target has larger h, use h to align
            target_height = bbox_h * \
                np.random.uniform(self.scale[0], self.scale[1])
            virtual_h = int(
                round(target_height))
            virtual_w = int(
                round(target_height / target_aspect_ratio))  # h/w
        else:
            # same w, source has larger h, use w to align
            target_width = bbox_w * \
                np.random.uniform(self.scale[0], self.scale[1])
            virtual_h = int(
                round(target_width * target_aspect_ratio))  # h/w
            virtual_w = int(
                round(target_width))

        # print("required aspect ratio:", target_aspect_ratio)

        virtual_top = int(round(bbox_top + jitter_h - (virtual_h-bbox_h)/2))
        virutal_left = int(round(bbox_left + jitter_w - (virtual_w-bbox_w)/2))

        if virtual_top < 0:
            top_padding = abs(virtual_top)
            crop_top = 0
        else:
            top_padding = 0
            crop_top = virtual_top
        if virutal_left < 0:
            left_padding = abs(virutal_left)
            crop_left = 0
        else:
            left_padding = 0
            crop_left = virutal_left
        if virtual_top+virtual_h > im_h:
            bottom_padding = abs(im_h-(virtual_top+virtual_h))
            crop_bottom = im_h
        else:
            bottom_padding = 0
            crop_bottom = virtual_top+virtual_h
        if virutal_left+virtual_w > im_w:
            right_padding = abs(im_w-(virutal_left+virtual_w))
            crop_right = im_w
        else:
            right_padding = 0
            crop_right = virutal_left+virtual_w
        # crop

        image = image[crop_top:crop_bottom, crop_left: crop_right]
        if label is not None:
            label = label[crop_top:crop_bottom, crop_left: crop_right]

        # pad
        if top_padding + bottom_padding + left_padding + right_padding > 0:
            padding = ((top_padding, bottom_padding),
                       (left_padding, right_padding), (0, 0))
            # print("padding", padding)
            image = np.pad(image, padding, mode='constant')
            if label is not None:
                label = np.pad(label, padding, mode='constant')

        if image.shape[0]/image.shape[1] - virtual_h/virtual_w > 0.001:
            print("virtual aspect ratio:", virtual_h/virtual_w)
            print("image aspect ratio:", image.shape[0]/image.shape[1])
        assert (image.shape[0]/image.shape[1] - virtual_h/virtual_w < 0.001)
        sample["crop"] = np.array(
            [im_h, im_w, crop_top, crop_bottom, crop_left, crop_right, top_padding, bottom_padding, left_padding, right_padding, image.shape[0], image.shape[1]])

        # resize
        if self.output_size[1] != image.shape[1] or self.output_size[0] != image.shape[0]:
            if self.output_size[1] > image.shape[1] and self.output_size[0] > image.shape[0]:
                # enlarging
                image = cv2.resize(
                    image, (self.output_size[1], self.output_size[0]), interpolation=cv2.INTER_LINEAR)
            else:
                # shrinking
                image = cv2.resize(
                    image, (self.output_size[1], self.output_size[0]), interpolation=cv2.INTER_AREA)

            if label is not None:
                label = cv2.resize(label, (self.output_size[1], self.output_size[0]),
                                   interpolation=cv2.INTER_NEAREST_EXACT)

        assert image.shape[0] == self.output_size[0] and image.shape[1] == self.output_size[1]
        sample['imidx'], sample["image_np"] = imidx, image
        if label is not None:
            assert label.shape[0] == self.output_size[0] and label.shape[1] == self.output_size[1]
            sample["labels"] = label

        return sample


class FileDataset(Dataset):
    def __init__(self, image_names_list, fg_img_lbl_transform=None, shader_pose_use_gt_udp_test=True, shader_target_use_gt_rgb_debug=False):
        self.image_name_list = image_names_list
        self.fg_img_lbl_transform = fg_img_lbl_transform
        self.shader_pose_use_gt_udp_test = shader_pose_use_gt_udp_test
        self.shader_target_use_gt_rgb_debug = shader_target_use_gt_rgb_debug

    def __len__(self):
        return len(self.image_name_list)

    def get_gt_from_disk(self, idx, imname, read_label):
        if read_label:
            # read label
            with open(imname, mode="rb") as bio:
                if imname.find(".npz") > 0:
                    label_np = np.load(bio, allow_pickle=True)[
                        'i'].astype(np.float32, copy=False)
                else:
                    label_np = cv2.cvtColor(cv2.imdecode(np.frombuffer(bio.read(
                    ), np.uint8), cv2.IMREAD_ANYCOLOR | cv2.IMREAD_ANYDEPTH | cv2.IMREAD_UNCHANGED), cv2.COLOR_BGRA2RGBA)
            assert (4 == label_np.shape[2])
            # fake image out of valid label
            image_np = (label_np*255).clip(0, 255).astype(np.uint8, copy=False)
            # assemble sample
            sample = {'imidx': np.array(
                [idx]), "image_np": image_np, "labels": label_np}

        else:
            # read image as uint8
            with open(imname, mode="rb") as bio:
                image_np = cv2.cvtColor(cv2.imdecode(np.frombuffer(
                    bio.read(), np.uint8), cv2.IMREAD_UNCHANGED), cv2.COLOR_BGRA2RGBA)
                # image_np = Image.open(bio)
                # image_np = np.array(image_np)
            assert (3 == len(image_np.shape))
            if (image_np.shape[2] == 4):
                mask_np = image_np[:, :, 3:4]
                image_np = (image_np[:, :, :3] *
                            (image_np[:, :, 3][:, :, np.newaxis]/255.0)).clip(0, 255).astype(np.uint8, copy=False)
            elif (image_np.shape[2] == 3):
                # generate a fake mask
                # Fool-proofing
                mask_np = np.ones(
                    (image_np.shape[0], image_np.shape[1], 1), dtype=np.uint8)*255
                print("WARN: transparent background is preferred for image ", imname)
            else:
                raise ValueError("weird shape of image ", imname, image_np)
            image_np = np.concatenate((image_np, mask_np), axis=2)
            sample = {'imidx': np.array(
                [idx]), "image_np": image_np}

        # apply fg_img_lbl_transform
        if self.fg_img_lbl_transform:
            sample = self.fg_img_lbl_transform(sample)

        if "labels" in sample:
            # return UDP as 4chn XYZV float tensor
            sample["labels"] = torch.from_numpy(
                sample["labels"].transpose((2, 0, 1)))
            assert (sample["labels"].dtype == torch.float32)

        if "image_np" in sample:
            # return image as 3chn RGB uint8 tensor and 1chn A uint8 tensor
            sample["mask"] = torch.from_numpy(
                sample["image_np"][:, :, 3:4].transpose((2, 0, 1)))
            assert (sample["mask"].dtype == torch.uint8)
            sample["image"] = torch.from_numpy(
                sample["image_np"][:, :, :3].transpose((2, 0, 1)))
            assert (sample["image"].dtype == torch.uint8)
            del sample["image_np"]
        return sample

    def __getitem__(self, idx):
        sample = {
            'imidx': np.array([idx])}
        target = self.get_gt_from_disk(
            idx, imname=self.image_name_list[idx][0], read_label=self.shader_pose_use_gt_udp_test)
        if self.shader_target_use_gt_rgb_debug:
            sample["pose_images"] = torch.stack([target["image"]])
            sample["pose_mask"] = target["mask"]
        elif self.shader_pose_use_gt_udp_test:
            sample["pose_label"] = target["labels"]
            sample["pose_mask"] = target["mask"]
        else:
            sample["pose_images"] = torch.stack([target["image"]])
        if "crop" in target:
            sample["pose_crop"] = target["crop"]
        character_images = []
        character_masks = []
        for i in range(1, len(self.image_name_list[idx])):
            source = self.get_gt_from_disk(
                idx, self.image_name_list[idx][i], read_label=False)
            character_images.append(source["image"])
            character_masks.append(source["mask"])
        character_images = torch.stack(character_images)
        character_masks = torch.stack(character_masks)
        sample.update({
            "character_images": character_images,
            "character_masks": character_masks
        })
        # do not make fake labels in inference
        return sample
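A minimal sketch of how a FileDataset entry is laid out (paths are hypothetical): each item of image_names_list lists the target pose file first and the character-sheet views after it, and the shapes below follow from __getitem__ above.

from torch.utils.data import DataLoader

files = [
    ["./test_data_pre/pose_0001.npz",     # target UDP label (read_label=True when shader_pose_use_gt_udp_test)
     "./character_sheet/front.png",       # character sheet view 1
     "./character_sheet/back.png"],       # character sheet view 2
]
ds = FileDataset(files, fg_img_lbl_transform=None, shader_pose_use_gt_udp_test=True)
loader = DataLoader(ds, batch_size=1, num_workers=0)

batch = next(iter(loader))
print(batch["pose_label"].shape)          # (1, 4, H, W) float32 XYZV
print(batch["character_images"].shape)    # (1, 2, 3, H, W) uint8 RGB
print(batch["character_masks"].shape)     # (1, 2, 1, H, W) uint8 alpha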
infer.sh
ADDED
@@ -0,0 +1,14 @@
rm -r "./results"
mkdir "./results"

rlaunch --gpu=1 --cpu=4 --memory=25600 -- python3 -m torch.distributed.launch \
    --nproc_per_node=1 train.py --mode=test \
    --world_size=1 --dataloaders=2 \
    --test_input_poses_images=./test_data/ \
    --test_input_person_images=./character_sheet/ \
    --test_output_dir=./results/ \
    --test_checkpoint_dir=./weights/

echo Generating video...
ffmpeg -r 30 -y -i ./results/%d.png -r 30 -c:v libx264 output.mp4 -r 30
echo DONE.
model/__init__.py
ADDED
@@ -0,0 +1 @@
model/backbone.py
ADDED
@@ -0,0 +1,285 @@
import torch
import torch.nn as nn
from torchvision import models
from torch.nn import functional as F


class AdaptiveConcatPool2d(nn.Module):
    """
    Layer that concats `AdaptiveAvgPool2d` and `AdaptiveMaxPool2d`.
    Source: Fastai. This code was taken from the fastai library at url
    https://github.com/fastai/fastai/blob/master/fastai/layers.py#L176
    """

    def __init__(self, sz=None):
        "Output will be 2*sz or 2 if sz is None"
        super().__init__()
        self.output_size = sz or 1
        self.ap = nn.AdaptiveAvgPool2d(self.output_size)
        self.mp = nn.AdaptiveMaxPool2d(self.output_size)

    def forward(self, x): return torch.cat([self.mp(x), self.ap(x)], 1)


class MyNorm(nn.Module):
    def __init__(self, num_channels):
        super(MyNorm, self).__init__()
        self.norm = nn.InstanceNorm2d(
            num_channels, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)

    def forward(self, x):
        x = self.norm(x)
        return x


def resnet_fastai(model, pretrained, url, replace_first_layer=None, replace_maxpool_layer=None, progress=True, map_location=None, **kwargs):
    cut = -2
    s = model(pretrained=False, **kwargs)
    if replace_maxpool_layer is not None:
        s.maxpool = replace_maxpool_layer
    if replace_first_layer is not None:
        body = nn.Sequential(replace_first_layer, *list(s.children())[1:cut])
    else:
        body = nn.Sequential(*list(s.children())[:cut])

    if pretrained:
        state = torch.hub.load_state_dict_from_url(url,
                                                   progress=progress, map_location=map_location)
        if replace_first_layer is not None:
            for each in list(state.keys()).copy():
                if each.find("0.0.") == 0:
                    del state[each]
        body_tail = nn.Sequential(body)
        ret = body_tail.load_state_dict(state, strict=False)
    return body


def get_backbone(name, pretrained=True, map_location=None):
    """ Loading backbone, defining names for skip-connections and encoder output. """

    first_layer_for_4chn = nn.Conv2d(
        4, 64, kernel_size=7, stride=2, padding=3, bias=False)
    max_pool_layer_replace = nn.Conv2d(
        64, 64, kernel_size=3, stride=2, padding=1, bias=False)
    # loading backbone model
    if name == 'resnet18':
        backbone = models.resnet18(pretrained=pretrained)
    elif name == 'resnet18-4':
        backbone = models.resnet18(pretrained=pretrained)
        backbone.conv1 = first_layer_for_4chn
    elif name == 'resnet34':
        backbone = models.resnet34(pretrained=pretrained)
    elif name == 'resnet50':
        backbone = models.resnet50(pretrained=False, norm_layer=MyNorm)
        backbone.maxpool = max_pool_layer_replace
    elif name == 'resnet101':
        backbone = models.resnet101(pretrained=pretrained)
    elif name == 'resnet152':
        backbone = models.resnet152(pretrained=pretrained)
    elif name == 'vgg16':
        backbone = models.vgg16_bn(pretrained=pretrained).features
    elif name == 'vgg19':
        backbone = models.vgg19_bn(pretrained=pretrained).features
    elif name == 'resnet18_danbo-4':
        backbone = resnet_fastai(models.resnet18, url="https://github.com/RF5/danbooru-pretrained/releases/download/v0.1/resnet18-3f77756f.pth",
                                 pretrained=pretrained, map_location=map_location, norm_layer=MyNorm, replace_first_layer=first_layer_for_4chn)
    elif name == 'resnet50_danbo':
        backbone = resnet_fastai(models.resnet50, url="https://github.com/RF5/danbooru-pretrained/releases/download/v0.1/resnet50-13306192.pth",
                                 pretrained=pretrained, map_location=map_location, norm_layer=MyNorm, replace_maxpool_layer=max_pool_layer_replace)
    elif name == 'densenet121':
        backbone = models.densenet121(pretrained=True).features
    elif name == 'densenet161':
        backbone = models.densenet161(pretrained=True).features
    elif name == 'densenet169':
        backbone = models.densenet169(pretrained=True).features
    elif name == 'densenet201':
        backbone = models.densenet201(pretrained=True).features
    else:
        raise NotImplementedError(
            '{} backbone model is not implemented so far.'.format(name))
    # print(backbone)
    # specifying skip feature and output names
    if name.startswith('resnet'):
        feature_names = [None, 'relu', 'layer1', 'layer2', 'layer3']
        backbone_output = 'layer4'
    elif name == 'vgg16':
        # TODO: consider using a 'bridge' for VGG models, there is just a MaxPool between last skip and backbone output
        feature_names = ['5', '12', '22', '32', '42']
        backbone_output = '43'
    elif name == 'vgg19':
        feature_names = ['5', '12', '25', '38', '51']
        backbone_output = '52'
    elif name.startswith('densenet'):
        feature_names = [None, 'relu0', 'denseblock1',
                         'denseblock2', 'denseblock3']
        backbone_output = 'denseblock4'
    elif name == 'unet_encoder':
        feature_names = ['module1', 'module2', 'module3', 'module4']
        backbone_output = 'module5'
    else:
        raise NotImplementedError(
            '{} backbone model is not implemented so far.'.format(name))
    if name.find('_danbo') > 0:
        feature_names = [None, '2', '4', '5', '6']
        backbone_output = '7'
    return backbone, feature_names, backbone_output


class UpsampleBlock(nn.Module):

    # TODO: separate parametric and non-parametric classes?
    # TODO: skip connection concatenated OR added

    def __init__(self, ch_in, ch_out=None, skip_in=0, use_bn=True, parametric=False):
        super(UpsampleBlock, self).__init__()

        self.parametric = parametric
        ch_out = ch_in/2 if ch_out is None else ch_out

        # first convolution: either transposed conv, or conv following the skip connection
        if parametric:
            # versions: kernel=4 padding=1, kernel=2 padding=0
            self.up = nn.ConvTranspose2d(in_channels=ch_in, out_channels=ch_out, kernel_size=(4, 4),
                                         stride=2, padding=1, output_padding=0, bias=(not use_bn))
            self.bn1 = MyNorm(ch_out) if use_bn else None
        else:
            self.up = None
            ch_in = ch_in + skip_in
            self.conv1 = nn.Conv2d(in_channels=ch_in, out_channels=ch_out, kernel_size=(3, 3),
                                   stride=1, padding=1, bias=(not use_bn))
            self.bn1 = MyNorm(ch_out) if use_bn else None

        self.relu = nn.ReLU(inplace=True)

        # second convolution
        conv2_in = ch_out if not parametric else ch_out + skip_in
        self.conv2 = nn.Conv2d(in_channels=conv2_in, out_channels=ch_out, kernel_size=(3, 3),
                               stride=1, padding=1, bias=(not use_bn))
        self.bn2 = MyNorm(ch_out) if use_bn else None

    def forward(self, x, skip_connection=None):

        x = self.up(x) if self.parametric else F.interpolate(x, size=None, scale_factor=2, mode='bilinear',
                                                             align_corners=None)
        if self.parametric:
            x = self.bn1(x) if self.bn1 is not None else x
            x = self.relu(x)

        if skip_connection is not None:
            x = torch.cat([x, skip_connection], dim=1)

        if not self.parametric:
            x = self.conv1(x)
            x = self.bn1(x) if self.bn1 is not None else x
            x = self.relu(x)
        x = self.conv2(x)
        x = self.bn2(x) if self.bn2 is not None else x
        x = self.relu(x)

        return x


class ResEncUnet(nn.Module):

    """ U-Net (https://arxiv.org/pdf/1505.04597.pdf) implementation with pre-trained torchvision backbones."""

    def __init__(self,
                 backbone_name,
                 pretrained=True,
                 encoder_freeze=False,
                 classes=21,
                 decoder_filters=(512, 256, 128, 64, 32),
                 parametric_upsampling=True,
                 shortcut_features='default',
                 decoder_use_instancenorm=True,
                 map_location=None
                 ):
        super(ResEncUnet, self).__init__()

        self.backbone_name = backbone_name

        self.backbone, self.shortcut_features, self.bb_out_name = get_backbone(
            backbone_name, pretrained=pretrained, map_location=map_location)
        shortcut_chs, bb_out_chs = self.infer_skip_channels()
        if shortcut_features != 'default':
            self.shortcut_features = shortcut_features

        # build decoder part
        self.upsample_blocks = nn.ModuleList()
        # avoiding having more blocks than skip connections
        decoder_filters = decoder_filters[:len(self.shortcut_features)]
        decoder_filters_in = [bb_out_chs] + list(decoder_filters[:-1])
        num_blocks = len(self.shortcut_features)
        for i, [filters_in, filters_out] in enumerate(zip(decoder_filters_in, decoder_filters)):
            self.upsample_blocks.append(UpsampleBlock(filters_in, filters_out,
                                                      skip_in=shortcut_chs[num_blocks-i-1],
                                                      parametric=parametric_upsampling,
                                                      use_bn=decoder_use_instancenorm))
        self.final_conv = nn.Conv2d(
            decoder_filters[-1], classes, kernel_size=(1, 1))

        if encoder_freeze:
            self.freeze_encoder()

    def freeze_encoder(self):
        """ Freezing encoder parameters, the newly initialized decoder parameters are remaining trainable. """

        for param in self.backbone.parameters():
            param.requires_grad = False

    def forward(self, *input, ret_parser_out=True):
        """ Forward propagation in U-Net. """

        x, features = self.forward_backbone(*input)
        output_feature = [x]
        for skip_name, upsample_block in zip(self.shortcut_features[::-1], self.upsample_blocks):
            skip_features = features[skip_name]
            if skip_features is not None:
                output_feature.append(skip_features)
            if ret_parser_out:
                x = upsample_block(x, skip_features)
        if ret_parser_out:
            x = self.final_conv(x)
            # apply sigmoid later
        else:
            x = None

        return x, output_feature

    def forward_backbone(self, x):
        """ Forward propagation in backbone encoder network. """

        features = {None: None} if None in self.shortcut_features else dict()
        for name, child in self.backbone.named_children():
            x = child(x)
            if name in self.shortcut_features:
                features[name] = x
            if name == self.bb_out_name:
                break

        return x, features

    def infer_skip_channels(self):
        """ Getting the number of channels at skip connections and at the output of the encoder. """
        if self.backbone_name.find("-4") > 0:
            x = torch.zeros(1, 4, 224, 224)
        else:
            x = torch.zeros(1, 3, 224, 224)
        has_fullres_features = self.backbone_name.startswith(
            'vgg') or self.backbone_name == 'unet_encoder'
        # only VGG has features at full resolution
        channels = [] if has_fullres_features else [0]

        # forward run in backbone to count channels (dirty solution but works for *any* Module)
        for name, child in self.backbone.named_children():
            x = child(x)
            if name in self.shortcut_features:
                channels.append(x.shape[1])
            if name == self.bb_out_name:
                out_channels = x.shape[1]
                break
        return channels, out_channels
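An illustrative shape check for ResEncUnet (not part of the commit), mirroring the instantiation used in conr.py; passing pretrained=False skips the danbooru-pretrained weight download, and the 256x256 input size is arbitrary.

import torch

net = ResEncUnet(backbone_name='resnet50_danbo', classes=4, pretrained=False,
                 parametric_upsampling=True, decoder_filters=(512, 384, 256, 128, 32))
x = torch.zeros(1, 3, 256, 256)
udp, features = net(x)
print(udp.shape)                       # (1, 4, 256, 256): XYZV map, sigmoid/clip applied later
print([f.shape[1] for f in features])  # [2048, 1024, 512, 256, 64]: encoder output plus skips, coarse to fine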
model/decoder_small.py
ADDED
@@ -0,0 +1,43 @@
from torch import nn
import torch.nn.functional as F
import torch


class ResBlock2d(nn.Module):
    def __init__(self, in_features, kernel_size, padding):
        super(ResBlock2d, self).__init__()
        self.conv1 = nn.Conv2d(in_channels=in_features, out_channels=in_features, kernel_size=kernel_size,
                               padding=padding)
        self.conv2 = nn.Conv2d(in_channels=in_features, out_channels=in_features, kernel_size=kernel_size,
                               padding=padding)

        self.norm1 = nn.Conv2d(
            in_channels=in_features, out_channels=in_features, kernel_size=1)
        self.norm2 = nn.Conv2d(
            in_channels=in_features, out_channels=in_features, kernel_size=1)

    def forward(self, x):
        out = self.norm1(x)
        out = F.relu(out, inplace=True)
        out = self.conv1(out)
        out = self.norm2(out)
        out = F.relu(out, inplace=True)
        out = self.conv2(out)
        out += x
        return out


class RGBADecoderNet(nn.Module):
    def __init__(self, c=64, out_planes=4, num_bottleneck_blocks=1):
        super(RGBADecoderNet, self).__init__()
        self.conv_rgba = nn.Sequential(nn.Conv2d(c, out_planes, kernel_size=3, stride=1,
                                                 padding=1, dilation=1, bias=True))
        self.bottleneck = torch.nn.Sequential()
        for i in range(num_bottleneck_blocks):
            self.bottleneck.add_module(
                'r' + str(i), ResBlock2d(c, kernel_size=(3, 3), padding=(1, 1)))

    def forward(self, features_weighted_mask_atfeaturesscale_list=[]):
        return torch.sigmoid(self.conv_rgba(self.bottleneck(features_weighted_mask_atfeaturesscale_list.pop(0))))
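A quick shape sketch for RGBADecoderNet (illustrative, not part of the commit): it pops the first (finest) entry of a feature list with c=64 channels in place and emits a sigmoid RGBA map at the same resolution.

import torch

dec = RGBADecoderNet(c=64, out_planes=4)
feats = [torch.zeros(1, 64, 256, 256)]  # dummy 64-channel feature map
rgba = dec(feats)                       # note: forward pops feats[0] in place
print(rgba.shape)                       # (1, 4, 256, 256), values in (0, 1)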
model/shader.py
ADDED
@@ -0,0 +1,290 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from .warplayer import warp_features
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


class DecoderBlock(nn.Module):
    def __init__(self, in_planes, c=224, out_msgs=0, out_locals=0, block_nums=1, out_masks=1, out_local_flows=32, out_msgs_flows=32, out_feat_flows=0):

        super(DecoderBlock, self).__init__()
        self.conv0 = nn.Sequential(
            nn.Conv2d(in_planes, c, 3, 2, 1),
            nn.PReLU(c),
            nn.Conv2d(c, c, 3, 2, 1),
            nn.PReLU(c),
        )

        self.convblocks = nn.ModuleList()
        for i in range(block_nums):
            self.convblocks.append(nn.Sequential(
                nn.Conv2d(c, c, 3, 1, 1),
                nn.PReLU(c),
                nn.Conv2d(c, c, 3, 1, 1),
                nn.PReLU(c),
                nn.Conv2d(c, c, 3, 1, 1),
                nn.PReLU(c),
                nn.Conv2d(c, c, 3, 1, 1),
                nn.PReLU(c),
                nn.Conv2d(c, c, 3, 1, 1),
                nn.PReLU(c),
                nn.Conv2d(c, c, 3, 1, 1),
                nn.PReLU(c),
            ))
        self.out_flows = 2
        self.out_msgs = out_msgs
        self.out_msgs_flows = out_msgs_flows if out_msgs > 0 else 0
        self.out_locals = out_locals
        self.out_local_flows = out_local_flows if out_locals > 0 else 0
        self.out_masks = out_masks
        self.out_feat_flows = out_feat_flows

        self.conv_last = nn.Sequential(
            nn.ConvTranspose2d(c, c, 4, 2, 1),
            nn.PReLU(c),
            nn.ConvTranspose2d(c, self.out_flows+self.out_msgs+self.out_msgs_flows +
                               self.out_locals+self.out_local_flows+self.out_masks+self.out_feat_flows, 4, 2, 1),
        )

    def forward(self, accumulated_flow, *other):
        x = [accumulated_flow]
        for each in other:
            if each is not None:
                assert (accumulated_flow.shape[-1] == each.shape[-1]), "decoder want {}, but get {}".format(
                    accumulated_flow.shape, each.shape)
                x.append(each)
        feat = self.conv0(torch.cat(x, dim=1))
        for convblock1 in self.convblocks:
            feat = convblock1(feat) + feat
        feat = self.conv_last(feat)
        prev = 0
        flow = feat[:, prev:prev+self.out_flows, :, :]
        prev += self.out_flows
        message = feat[:, prev:prev+self.out_msgs,
                       :, :] if self.out_msgs > 0 else None
        prev += self.out_msgs
        message_flow = feat[:, prev:prev + self.out_msgs_flows,
                            :, :] if self.out_msgs_flows > 0 else None
        prev += self.out_msgs_flows
        local_message = feat[:, prev:prev + self.out_locals,
                             :, :] if self.out_locals > 0 else None
        prev += self.out_locals
        local_message_flow = feat[:, prev:prev+self.out_local_flows,
                                  :, :] if self.out_local_flows > 0 else None
        prev += self.out_local_flows
        mask = torch.sigmoid(
            feat[:, prev:prev+self.out_masks, :, :]) if self.out_masks > 0 else None
        prev += self.out_masks
        feat_flow = feat[:, prev:prev+self.out_feat_flows,
                         :, :] if self.out_feat_flows > 0 else None
        prev += self.out_feat_flows
        return flow, mask, message, message_flow, local_message, local_message_flow, feat_flow


class CINN(nn.Module):
    def __init__(self, DIM_SHADER_REFERENCE, target_feature_chns=[512, 256, 128, 64, 64], feature_chns=[2048, 1024, 512, 256, 64], out_msgs_chn=[2048, 1024, 512, 256, 64, 64], out_locals_chn=[2048, 1024, 512, 256, 64, 0], block_num=[1, 1, 1, 1, 1, 2], block_chn_num=[224, 224, 224, 224, 224, 224]):
        super(CINN, self).__init__()

        self.in_msgs_chn = [0, *out_msgs_chn[:-1]]
        self.in_locals_chn = [0, *out_locals_chn[:-1]]

        self.decoder_blocks = nn.ModuleList()
        self.feed_weighted = True
        if self.feed_weighted:
            in_planes = 2+2+DIM_SHADER_REFERENCE*2
        else:
            in_planes = 2+DIM_SHADER_REFERENCE
        for each_target_feature_chns, each_feature_chns, each_out_msgs_chn, each_out_locals_chn, each_in_msgs_chn, each_in_locals_chn, each_block_num, each_block_chn_num in zip(target_feature_chns, feature_chns, out_msgs_chn, out_locals_chn, self.in_msgs_chn, self.in_locals_chn, block_num, block_chn_num):
            self.decoder_blocks.append(
                DecoderBlock(in_planes+each_target_feature_chns+each_feature_chns+each_in_locals_chn+each_in_msgs_chn, c=each_block_chn_num, block_nums=each_block_num, out_msgs=each_out_msgs_chn, out_locals=each_out_locals_chn, out_masks=2+each_out_locals_chn))
        for i in range(len(feature_chns), len(out_locals_chn)):
            # print("append extra block", i, "msg",
            #       out_msgs_chn[i], "local", out_locals_chn[i], "block", block_num[i])
            self.decoder_blocks.append(
                DecoderBlock(in_planes+self.in_msgs_chn[i]+self.in_locals_chn[i], c=block_chn_num[i], block_nums=block_num[i], out_msgs=out_msgs_chn[i], out_locals=out_locals_chn[i], out_masks=2+out_msgs_chn[i], out_feat_flows=0))

    def apply_flow(self, mask, message, message_flow, local_message, local_message_flow, x_reference, accumulated_flow, each_x_reference_features=None, each_x_reference_features_flow=None):
        if each_x_reference_features is not None:
            size_from = each_x_reference_features
        else:
            size_from = x_reference
        f_size = (size_from.shape[2], size_from.shape[3])
        accumulated_flow = self.flow_rescale(
            accumulated_flow, size_from)
        # mask = warp_features(F.interpolate(
        #     mask, size=f_size, mode="bilinear"), accumulated_flow) if mask is not None else None
        mask = F.interpolate(
            mask, size=f_size, mode="bilinear") if mask is not None else None
        message = F.interpolate(
            message, size=f_size, mode="bilinear") if message is not None else None
        message_flow = self.flow_rescale(
            message_flow, size_from) if message_flow is not None else None
        message = warp_features(
            message, message_flow) if message_flow is not None else message

        local_message = F.interpolate(
            local_message, size=f_size, mode="bilinear") if local_message is not None else None
        local_message_flow = self.flow_rescale(
            local_message_flow, size_from) if local_message_flow is not None else None
        local_message = warp_features(
            local_message, local_message_flow) if local_message_flow is not None else local_message

        warp_x_reference = warp_features(F.interpolate(
            x_reference, size=f_size, mode="bilinear"), accumulated_flow)

        each_x_reference_features_flow = self.flow_rescale(
            each_x_reference_features_flow, size_from) if (each_x_reference_features is not None and each_x_reference_features_flow is not None) else None
        warp_each_x_reference_features = warp_features(
            each_x_reference_features, each_x_reference_features_flow) if each_x_reference_features_flow is not None else each_x_reference_features

        return mask, message, local_message, warp_x_reference, accumulated_flow, warp_each_x_reference_features, each_x_reference_features_flow

    def forward(self, x_target_features=[], x_reference=None, x_reference_features=[]):
        y_flow = []
        y_feat_flow = []

        y_local_message = []
        y_warp_x_reference = []
        y_warp_x_reference_features = []

        y_weighted_flow = []
        y_weighted_mask = []
        y_weighted_message = []
        y_weighted_x_reference = []
        y_weighted_x_reference_features = []

        for pyrlevel, ifblock in enumerate(self.decoder_blocks):
            stacked_wref = []
            stacked_feat = []
            stacked_anci = []
            stacked_flow = []
            stacked_mask = []
            stacked_mesg = []
            stacked_locm = []
            stacked_feat_flow = []
            for view_id in range(x_reference.shape[1]):  # NMCHW

                if pyrlevel == 0:
                    # create from zero flow
                    feat_ev = x_reference_features[pyrlevel][:,
                                                             view_id, :, :, :] if pyrlevel < len(x_reference_features) else None

                    accumulated_flow = torch.zeros_like(
                        feat_ev[:, :2, :, :]).to(device)
                    accumulated_feat_flow = torch.zeros_like(
                        feat_ev[:, :32, :, :]).to(device)
                    # domestic inputs
                    warp_x_reference = F.interpolate(x_reference[:, view_id, :, :, :], size=(
                        feat_ev.shape[-2], feat_ev.shape[-1]), mode="bilinear")
                    warp_x_reference_features = feat_ev

                    local_message = None
                    # federated inputs
                    weighted_flow = accumulated_flow if self.feed_weighted else None
                    weighted_wref = warp_x_reference if self.feed_weighted else None
                    weighted_message = None
                else:
                    # resume from last layer
                    accumulated_flow = y_flow[-1][:, view_id, :, :, :]
                    accumulated_feat_flow = y_feat_flow[-1][:,
                                                            view_id, :, :, :] if y_feat_flow[-1] is not None else None
                    # domestic inputs
                    warp_x_reference = y_warp_x_reference[-1][:,
                                                              view_id, :, :, :]
                    warp_x_reference_features = y_warp_x_reference_features[-1][:,
                                                                                view_id, :, :, :] if y_warp_x_reference_features[-1] is not None else None
                    local_message = y_local_message[-1][:, view_id, :,
                                                        :, :] if len(y_local_message) > 0 else None

                    # federated inputs
                    weighted_flow = y_weighted_flow[-1] if self.feed_weighted else None
                    weighted_wref = y_weighted_x_reference[-1] if self.feed_weighted else None
                    weighted_message = y_weighted_message[-1] if len(
                        y_weighted_message) > 0 else None
                scaled_x_target = x_target_features[pyrlevel][:, :, :, :].detach() if pyrlevel < len(
                    x_target_features) else None
                # compute flow
                residual_flow, mask, message, message_flow, local_message, local_message_flow, residual_feat_flow = ifblock(
                    accumulated_flow, scaled_x_target, warp_x_reference, warp_x_reference_features, weighted_flow, weighted_wref, weighted_message, local_message)
                accumulated_flow = residual_flow + accumulated_flow
                accumulated_feat_flow = accumulated_flow

                feat_ev = x_reference_features[pyrlevel+1][:,
                                                           view_id, :, :, :] if pyrlevel+1 < len(x_reference_features) else None
                mask, message, local_message, warp_x_reference, accumulated_flow, warp_x_reference_features, accumulated_feat_flow = self.apply_flow(
                    mask, message, message_flow, local_message, local_message_flow, x_reference[:, view_id, :, :, :], accumulated_flow, feat_ev, accumulated_feat_flow)
                stacked_flow.append(accumulated_flow)
                if accumulated_feat_flow is not None:
                    stacked_feat_flow.append(accumulated_feat_flow)
                stacked_mask.append(mask)
                if message is not None:
                    stacked_mesg.append(message)
                if local_message is not None:
                    stacked_locm.append(local_message)
                stacked_wref.append(warp_x_reference)
                if warp_x_reference_features is not None:
                    stacked_feat.append(warp_x_reference_features)

            stacked_flow = torch.stack(stacked_flow, dim=1)  # M*NCHW -> NMCHW
            stacked_feat_flow = torch.stack(stacked_feat_flow, dim=1) if len(
                stacked_feat_flow) > 0 else None
            stacked_mask = torch.stack(
                stacked_mask, dim=1)

            stacked_mesg = torch.stack(stacked_mesg, dim=1) if len(
|
| 236 |
+
stacked_mesg) > 0 else None
|
| 237 |
+
stacked_locm = torch.stack(stacked_locm, dim=1) if len(
|
| 238 |
+
stacked_locm) > 0 else None
|
| 239 |
+
|
| 240 |
+
stacked_wref = torch.stack(stacked_wref, dim=1)
|
| 241 |
+
stacked_feat = torch.stack(stacked_feat, dim=1) if len(
|
| 242 |
+
stacked_feat) > 0 else None
|
| 243 |
+
stacked_anci = torch.stack(stacked_anci, dim=1) if len(
|
| 244 |
+
stacked_anci) > 0 else None
|
| 245 |
+
y_flow.append(stacked_flow)
|
| 246 |
+
y_feat_flow.append(stacked_feat_flow)
|
| 247 |
+
|
| 248 |
+
y_warp_x_reference.append(stacked_wref)
|
| 249 |
+
y_warp_x_reference_features.append(stacked_feat)
|
| 250 |
+
# compute normalized confidence
|
| 251 |
+
stacked_contrib = torch.nn.functional.softmax(stacked_mask, dim=1)
|
| 252 |
+
|
| 253 |
+
# torch.sum to remove temp dimension M from NMCHW --> NCHW
|
| 254 |
+
weighted_flow = torch.sum(
|
| 255 |
+
stacked_mask[:, :, 0:1, :, :] * stacked_contrib[:, :, 0:1, :, :] * stacked_flow, dim=1)
|
| 256 |
+
weighted_mask = torch.sum(
|
| 257 |
+
stacked_contrib[:, :, 0:1, :, :] * stacked_mask[:, :, 0:1, :, :], dim=1)
|
| 258 |
+
weighted_wref = torch.sum(
|
| 259 |
+
stacked_mask[:, :, 0:1, :, :] * stacked_contrib[:, :, 0:1, :, :] * stacked_wref, dim=1) if stacked_wref is not None else None
|
| 260 |
+
weighted_feat = torch.sum(
|
| 261 |
+
stacked_mask[:, :, 1:2, :, :] * stacked_contrib[:, :, 1:2, :, :] * stacked_feat, dim=1) if stacked_feat is not None else None
|
| 262 |
+
weighted_mesg = torch.sum(
|
| 263 |
+
stacked_mask[:, :, 2:, :, :] * stacked_contrib[:, :, 2:, :, :] * stacked_mesg, dim=1) if stacked_mesg is not None else None
|
| 264 |
+
y_weighted_flow.append(weighted_flow)
|
| 265 |
+
y_weighted_mask.append(weighted_mask)
|
| 266 |
+
if weighted_mesg is not None:
|
| 267 |
+
y_weighted_message.append(weighted_mesg)
|
| 268 |
+
if stacked_locm is not None:
|
| 269 |
+
y_local_message.append(stacked_locm)
|
| 270 |
+
y_weighted_message.append(weighted_mesg)
|
| 271 |
+
y_weighted_x_reference.append(weighted_wref)
|
| 272 |
+
y_weighted_x_reference_features.append(weighted_feat)
|
| 273 |
+
|
| 274 |
+
if weighted_feat is not None:
|
| 275 |
+
y_weighted_x_reference_features.append(weighted_feat)
|
| 276 |
+
return {
|
| 277 |
+
"y_last_remote_features": [weighted_mesg],
|
| 278 |
+
}
|
| 279 |
+
|
| 280 |
+
def flow_rescale(self, prev_flow, each_x_reference_features):
|
| 281 |
+
if prev_flow is None:
|
| 282 |
+
prev_flow = torch.zeros_like(
|
| 283 |
+
each_x_reference_features[:, :2]).to(device)
|
| 284 |
+
else:
|
| 285 |
+
up_scale_factor = each_x_reference_features.shape[-1] / \
|
| 286 |
+
prev_flow.shape[-1]
|
| 287 |
+
if up_scale_factor != 1:
|
| 288 |
+
prev_flow = F.interpolate(prev_flow, scale_factor=up_scale_factor, mode="bilinear",
|
| 289 |
+
align_corners=False, recompute_scale_factor=False) * up_scale_factor
|
| 290 |
+
return prev_flow
|
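For orientation: at the end of each pyramid level, forward() normalizes the per-view confidence channel with a softmax over the view dimension M and uses it to average the per-view flows, warped colors and features back down to a single NCHW tensor. The snippet below is a minimal standalone sketch of that weighting, not part of the repository; fuse_views and the toy shapes are illustrative only.

# Minimal sketch (assumed helper, not the CoNR API) of the view-fusion step above:
# confidence maps are softmax-normalized across the M views, then used to
# average per-view predictions (flow, RGB or features) into one NCHW tensor.
import torch

def fuse_views(stacked_mask, stacked_value):
    # stacked_mask:  N x M x 1 x H x W  (one confidence channel per view)
    # stacked_value: N x M x C x H x W  (per-view flow, colors or features)
    contrib = torch.softmax(stacked_mask, dim=1)                       # normalize across views
    return torch.sum(stacked_mask * contrib * stacked_value, dim=1)    # N x C x H x W

if __name__ == "__main__":
    n, m, c, h, w = 2, 4, 3, 8, 8
    mask = torch.randn(n, m, 1, h, w)
    value = torch.randn(n, m, c, h, w)
    print(fuse_views(mask, value).shape)  # torch.Size([2, 3, 8, 8])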
model/warplayer.py
ADDED
@@ -0,0 +1,56 @@
import torch
import numpy as np

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
backwarp_tenGrid = {}


def warp(tenInput, tenFlow):
    with torch.cuda.amp.autocast(enabled=False):
        k = (str(tenFlow.device), str(tenFlow.size()))
        if k not in backwarp_tenGrid:
            tenHorizontal = torch.linspace(-1.0, 1.0, tenFlow.shape[3], device=device).view(
                1, 1, 1, tenFlow.shape[3]).expand(tenFlow.shape[0], -1, tenFlow.shape[2], -1)
            tenVertical = torch.linspace(-1.0, 1.0, tenFlow.shape[2], device=device).view(
                1, 1, tenFlow.shape[2], 1).expand(tenFlow.shape[0], -1, -1, tenFlow.shape[3])
            backwarp_tenGrid[k] = torch.cat(
                [tenHorizontal, tenVertical], 1).to(device)

        tenFlow = torch.cat([tenFlow[:, 0:1, :, :] / ((tenInput.shape[3] - 1.0) / 2.0),
                             tenFlow[:, 1:2, :, :] / ((tenInput.shape[2] - 1.0) / 2.0)], 1)

        g = (backwarp_tenGrid[k] + tenFlow).permute(0, 2, 3, 1)
        if tenInput.dtype != g.dtype:
            g = g.to(tenInput.dtype)
        return torch.nn.functional.grid_sample(input=tenInput, grid=g, mode='bilinear', padding_mode='border', align_corners=True)
        # "zeros" "border"


def warp_features(inp, flow):
    groups = flow.shape[1]//2  # NCHW
    samples = inp.shape[0]
    h = inp.shape[2]
    w = inp.shape[3]
    assert(flow.shape[0] == samples and flow.shape[2]
           == h and flow.shape[3] == w)
    chns = inp.shape[1]
    chns_per_group = chns // groups
    assert(flow.shape[1] % 2 == 0)
    assert(chns % groups == 0)
    inp = inp.contiguous().view(samples*groups, chns_per_group, h, w)
    flow = flow.contiguous().view(samples*groups, 2, h, w)
    feat = warp(inp, flow)
    feat = feat.view(samples, chns, h, w)
    return feat


def flow2rgb(flow_map_np):
    h, w, _ = flow_map_np.shape
    rgb_map = np.ones((h, w, 3)).astype(np.float32)/2.0
    normalized_flow_map = np.concatenate(
        (flow_map_np[:, :, 0:1]/h/2.0, flow_map_np[:, :, 1:2]/w/2.0), axis=2)
    rgb_map[:, :, 0] += normalized_flow_map[:, :, 0]
    rgb_map[:, :, 1] -= 0.5 * \
        (normalized_flow_map[:, :, 0] + normalized_flow_map[:, :, 1])
    rgb_map[:, :, 2] += normalized_flow_map[:, :, 1]
    return (rgb_map.clip(0, 1)*255.0).astype(np.uint8)
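A brief usage sketch for the helpers above: warp backward-warps an NCHW tensor by a 2-channel flow through grid_sample, and warp_features accepts a flow with 2·k channels so that the input channels can be split into k groups, each warped by its own flow. Only the module path model/warplayer.py is taken from this commit; the tensors, shapes and the identity-warp claim below are illustrative assumptions.

# Hedged usage sketch, not part of the repo files above.
import torch
from model.warplayer import warp_features

features = torch.randn(1, 64, 32, 32)      # N, C, H, W
flow = torch.zeros(1, 2, 32, 32)            # one flow group of 2 channels -> identity warp
same = warp_features(features, flow)         # matches `features` (zero displacement)
grouped_flow = torch.zeros(1, 8, 32, 32)     # 4 groups of 2 channels; 64 channels / 4 groups = 16 per group
out = warp_features(features, grouped_flow)  # each 16-channel group warped by its own flow
print(same.shape, out.shape)                 # torch.Size([1, 64, 32, 32]) twice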
streamlit.py
ADDED
@@ -0,0 +1,52 @@
import cv2
import numpy as np
import streamlit as st
import os
import base64

st.set_page_config(layout="wide", page_title='CoNR demo', page_icon="🪐")

st.title('CoNR demo')
st.markdown(""" <style>
#MainMenu {visibility: hidden;}
footer {visibility: hidden;}
</style> """, unsafe_allow_html=True)

def get_base64(bin_file):
    with open(bin_file, 'rb') as f:
        data = f.read()
    return base64.b64encode(data).decode()

# def set_background(png_file):
#     bin_str = get_base64(png_file)
#     page_bg_img = '''
#     <style>
#     .stApp {
#     background-image: url("data:image/png;base64,%s");
#     background-size: 1920px 1080px;
#     background-attachment:fixed;
#     background-position:center;
#     background-repeat:no-repeat;
#     }
#     </style>
#     ''' % bin_str
#     st.markdown(page_bg_img, unsafe_allow_html=True)

# set_background('ipad_bg.png')

upload_img = st.file_uploader("Input character sheet", "png", accept_multiple_files=True)

if st.button('RUN!'):
    if upload_img is not None:
        for i in range(len(upload_img)):
            with open('character_sheet/{}.png'.format(i), 'wb') as f:
                f.write(upload_img[i].read())

        st.info('Running inference...')
        os.system('sh infer.sh')
        st.info('Done!')
        video_file = open('output.mp4', 'rb')
        video_bytes = video_file.read()
        st.video(video_bytes, start_time=0)
    else:
        st.info('No images uploaded yet > <')
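The app expects infer.sh to leave an output.mp4 next to it; that script is not shown in this section, so the following is only a hypothetical sketch of how the PNG frames that train.py writes to ./results/ could be packed into such a video with OpenCV. The frame naming (plain integer indices, as written by save_output), the frame rate and the codec availability of the local OpenCV build are all assumptions.

# Hypothetical helper, not part of the repo: pack ./results/<index>.png into output.mp4
# so that st.video() above has something to play.
import os
import cv2

def frames_to_video(frame_dir="results", out_path="output.mp4", fps=30):
    # keep only integer-named frames and sort them numerically
    names = sorted((f for f in os.listdir(frame_dir)
                    if f.endswith(".png") and f[:-4].isdigit()),
                   key=lambda f: int(f[:-4]))
    if not names:
        raise RuntimeError("no frames found in " + frame_dir)
    first = cv2.imread(os.path.join(frame_dir, names[0]))
    h, w = first.shape[:2]
    writer = cv2.VideoWriter(out_path, cv2.VideoWriter_fourcc(*"mp4v"), fps, (w, h))
    for name in names:
        writer.write(cv2.imread(os.path.join(frame_dir, name)))  # imread drops the alpha channel
    writer.release()

if __name__ == "__main__":
    frames_to_video()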
train.py
ADDED
@@ -0,0 +1,229 @@
import argparse
import os
import time
from datetime import datetime
from distutils.util import strtobool

import numpy as np
import torch
from torch.utils.data import DataLoader
from torchvision import transforms
from data_loader import (FileDataset,
                         RandomResizedCropWithAutoCenteringAndZeroPadding)
from torch.utils.data.distributed import DistributedSampler
from conr import CoNR

def data_sampler(dataset, shuffle, distributed):

    if distributed:
        return torch.utils.data.distributed.DistributedSampler(dataset, shuffle=shuffle)

    if shuffle:
        return torch.utils.data.RandomSampler(dataset)

    else:
        return torch.utils.data.SequentialSampler(dataset)

def save_output(image_name, inputs_v, d_dir=".", crop=None):
    import cv2

    inputs_v = inputs_v.detach().squeeze()
    input_np = torch.clamp(inputs_v*255, 0, 255).byte().cpu().numpy().transpose(
        (1, 2, 0))
    # cv2.setNumThreads(1)
    out_render_scale = cv2.cvtColor(input_np, cv2.COLOR_RGBA2BGRA)
    if crop is not None:
        crop = crop.cpu().numpy()[0]
        output_img = np.zeros((crop[0], crop[1], 4), dtype=np.uint8)
        before_resize_scale = cv2.resize(
            out_render_scale, (crop[5]-crop[4]+crop[8]+crop[9], crop[3]-crop[2]+crop[6]+crop[7]), interpolation=cv2.INTER_AREA)  # w,h
        output_img[crop[2]:crop[3], crop[4]:crop[5]] = before_resize_scale[crop[6]:before_resize_scale.shape[0] -
                                                                           crop[7], crop[8]:before_resize_scale.shape[1]-crop[9]]
    else:
        output_img = out_render_scale
    cv2.imwrite(d_dir+"/"+image_name.split(os.sep)[-1]+'.png',
                output_img
                )


def test():
    source_names_list = []
    for name in sorted(os.listdir(args.test_input_person_images)):
        thissource = os.path.join(args.test_input_person_images, name)
        if os.path.isfile(thissource):
            source_names_list.append(thissource)
        if os.path.isdir(thissource):
            print("skipping empty folder :"+thissource)

    image_names_list = []
    for name in sorted(os.listdir(args.test_input_poses_images)):
        thistarget = os.path.join(args.test_input_poses_images, name)
        if os.path.isfile(thistarget):
            image_names_list.append([thistarget, *source_names_list])
        if os.path.isdir(thistarget):
            print("skipping folder :"+thistarget)
    print(image_names_list)

    print("---building models")
    conrmodel = CoNR(args)
    conrmodel.load_model(path=args.test_checkpoint_dir)
    conrmodel.dist()
    infer(args, conrmodel, image_names_list)

# def test():
#     source_names_list = []
#     for name in os.listdir(args.test_input_person_images):
#         thissource = os.path.join(args.test_input_person_images, name)
#         if os.path.isfile(thissource):
#             source_names_list.append([thissource])
#         if os.path.isdir(thissource):
#             toadd = [os.path.join(thissource, this_file)
#                      for this_file in os.listdir(thissource)]
#             if (toadd != []):
#                 source_names_list.append(toadd)
#             else:
#                 print("skipping empty folder :"+thissource)
#     image_names_list = []
#     for eachlist in source_names_list:
#         for name in sorted(os.listdir(args.test_input_poses_images)):
#             thistarget = os.path.join(args.test_input_poses_images, name)
#             if os.path.isfile(thistarget):
#                 image_names_list.append([thistarget, *eachlist])
#             if os.path.isdir(thistarget):
#                 print("skipping folder :"+thistarget)

#     print(image_names_list)
#     print("---building models...")
#     conrmodel = CoNR(args)
#     conrmodel.load_model(path=args.test_checkpoint_dir)
#     conrmodel.dist()
#     infer(args, conrmodel, image_names_list)


def infer(args, humanflowmodel, image_names_list):
    print("---test images: ", len(image_names_list))
    test_salobj_dataset = FileDataset(image_names_list=image_names_list,
                                      fg_img_lbl_transform=transforms.Compose([
                                          RandomResizedCropWithAutoCenteringAndZeroPadding(
                                              (args.dataloader_imgsize, args.dataloader_imgsize), scale=(1, 1), ratio=(1.0, 1.0), center_jitter=(0.0, 0.0)
                                          )]),
                                      shader_pose_use_gt_udp_test=not args.test_pose_use_parser_udp,
                                      shader_target_use_gt_rgb_debug=False
                                      )
    sampler = data_sampler(test_salobj_dataset, shuffle=False,
                           distributed=args.distributed)
    train_data = DataLoader(test_salobj_dataset,
                            batch_size=1,
                            shuffle=False, sampler=sampler,
                            num_workers=args.dataloaders)

    # start testing

    train_num = train_data.__len__()
    time_stamp = time.time()
    prev_frame_rgb = []
    prev_frame_a = []
    for i, data in enumerate(train_data):
        data_time_interval = time.time() - time_stamp
        time_stamp = time.time()
        with torch.no_grad():
            data["character_images"] = torch.cat(
                [data["character_images"], *prev_frame_rgb], dim=1)
            data["character_masks"] = torch.cat(
                [data["character_masks"], *prev_frame_a], dim=1)
            data = humanflowmodel.data_norm_image(data)
            pred = humanflowmodel.model_step(data, training=False)
            # remember to call humanflowmodel.reset_charactersheet() if you change character .

        train_time_interval = time.time() - time_stamp
        time_stamp = time.time()
        if i % 5 == 0 and args.local_rank == 0:
            print("[infer batch: %4d/%4d] time:%2f+%2f" % (
                i, train_num,
                data_time_interval, train_time_interval
            ))
        with torch.no_grad():

            if args.test_output_video:
                pred_img = pred["shader"]["y_weighted_warp_decoded_rgba"]
                save_output(
                    str(int(data["imidx"].cpu().item())), pred_img, args.test_output_dir, crop=data["pose_crop"])

            if args.test_output_udp:
                pred_img = pred["shader"]["x_target_sudp_a"]
                save_output(
                    "udp_"+str(int(data["imidx"].cpu().item())), pred_img, args.test_output_dir)


def build_args():
    parser = argparse.ArgumentParser()
    # distributed learning settings
    parser.add_argument("--world_size", type=int, default=1,
                        help='world size')
    parser.add_argument("--local_rank", type=int, default=0,
                        help='local_rank, DON\'T change it')

    # model settings
    parser.add_argument('--dataloader_imgsize', type=int, default=256,
                        help='Input image size of the model')
    parser.add_argument('--batch_size', type=int, default=4,
                        help='minibatch size')
    parser.add_argument('--model_name', default='model_result',
                        help='Name of the experiment')
    parser.add_argument('--dataloaders', type=int, default=2,
                        help='Num of dataloaders')
    parser.add_argument('--mode', default="test", choices=['train', 'test'],
                        help='Training mode or Testing mode')

    # i/o settings
    parser.add_argument('--test_input_person_images',
                        type=str, default="./character_sheet/",
                        help='Directory to input character sheets')
    parser.add_argument('--test_input_poses_images', type=str,
                        default="./test_data/",
                        help='Directory to input UDP sequences or pose images')
    parser.add_argument('--test_checkpoint_dir', type=str,
                        default='./weights/',
                        help='Directory to model weights')
    parser.add_argument('--test_output_dir', type=str,
                        default="./results/",
                        help='Directory to output images')

    # output content settings
    parser.add_argument('--test_output_video', type=strtobool, default=True,
                        help='Whether to output the final result of CoNR, \
                            images will be output to test_output_dir while True.')
    parser.add_argument('--test_output_udp', type=strtobool, default=False,
                        help='Whether to output UDP generated from UDP detector, \
                            this is meaningful ONLY when test_input_poses_images \
                            is not UDP sequences but pose images. Meanwhile, \
                            test_pose_use_parser_udp need to be True')

    # UDP detector settings
    parser.add_argument('--test_pose_use_parser_udp',
                        type=strtobool, default=False,
                        help='Whether to use UDP detector to generate UDP from pngs, \
                            pose input MUST be pose images instead of UDP sequences \
                            while True')

    args = parser.parse_args()

    args.distributed = (args.world_size > 1)
    if args.local_rank == 0:
        print("batch_size:", args.batch_size, flush=True)
    if args.distributed:
        if args.local_rank == 0:
            print("world_size: ", args.world_size)
        torch.distributed.init_process_group(
            backend="nccl", init_method="env://", world_size=args.world_size)
        torch.cuda.set_device(args.local_rank)
        torch.backends.cudnn.benchmark = True
    else:
        args.local_rank = 0

    return args


if __name__ == "__main__":
    args = build_args()
    test()
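The test-mode entry point above is driven entirely by the flags defined in build_args(). As a hedged example, a single-process run with the parser's own defaults spelled out might be launched as below; the flag names come from the parser above, while the subprocess wrapper is only for illustration (any shell runner would do).

# Hedged invocation example: run inference with the defaults made explicit.
import subprocess

subprocess.run([
    "python", "train.py",
    "--mode", "test",
    "--dataloader_imgsize", "256",
    "--test_input_person_images", "./character_sheet/",
    "--test_input_poses_images", "./test_data/",
    "--test_checkpoint_dir", "./weights/",
    "--test_output_dir", "./results/",
], check=True)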