# MouthNet: extracts an identity-style feature embedding from the mouth region of a face crop.
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| import math | |
| from third_party.arcface.iresnet import iresnet50, iresnet100 | |
class MouthNet(nn.Module):
    """Mouth-region feature extractor.

    Wraps an iresnet50 backbone whose fully-connected input scale is derived
    from the size of the mouth crop, plus a (currently unused) BiSeNet face
    parser for optional mouth masking.
    """

    def __init__(self,
                 bisenet: nn.Module,
                 feature_dim: int = 64,
                 crop_param: tuple = (0, 0, 112, 112),
                 iresnet_pretrained: bool = False,
                 ):
        """
        Args:
            bisenet: pretrained face-parsing network (19-class BiSeNet).
            feature_dim: dimensionality of the output embedding.
            crop_param: PIL-style crop box (left, upper, right, lower).
            iresnet_pretrained: whether to load pretrained iresnet50 weights.
        """
        super().__init__()

        # Crop height/width in pixels, from the (left, upper, right, lower) box.
        crop_h = crop_param[3] - crop_param[1]
        crop_w = crop_param[2] - crop_param[0]
        # A full 112x112 input yields a 7x7 feature map before the FC layer;
        # scale that cell count proportionally to the crop size.
        fc_scale = int(math.ceil(crop_h / 112 * 7) * math.ceil(crop_w / 112 * 7))

        self.bisenet = bisenet
        self.backbone = iresnet50(
            pretrained=iresnet_pretrained,
            num_features=feature_dim,
            fp16=False,
            fc_scale=fc_scale,
        )

        # ImageNet/VGG normalization constants, shaped (3, 1, 1) for broadcasting.
        self.register_buffer(
            name="vgg_mean",
            tensor=torch.tensor([[[0.485]], [[0.456]], [[0.406]]]),
        )
        self.register_buffer(
            name="vgg_std",
            tensor=torch.tensor([[[0.229]], [[0.224]], [[0.225]]]),
        )

    def forward(self, x):
        """Return the backbone embedding for *x*.

        NOTE(review): mouth masking via
        ``self.get_any_mask(x, par=[11, 12, 13], normalized=True)`` is
        currently disabled; a constant mask of 1 passes the image through
        unchanged.
        """
        mask = 1
        return self.backbone(x * mask)

    def get_any_mask(self, img, par, normalized=False):
        """Build a soft binary mask selecting the parsing classes in *par*.

        Class indices: 0 background, 1 skin, 2 l_brow, 3 r_brow, 4 l_eye,
        5 r_eye, 6 eye_g, 7 l_ear, 8 r_ear, 9 ear_r, 10 nose, 11 mouth,
        12 u_lip, 13 l_lip, 14 neck, 15 neck_l, 16 cloth, 17 hair, 18 hat.

        Args:
            img: input batch; resized to 512 for parsing, mask is resized back.
            par: iterable of class indices to include in the mask.
            normalized: if False, *img* is assumed to be in [-1, 1] and is
                shifted to [0, 1] before VGG normalization.
        """
        original_size = img.size()[-1]
        with torch.no_grad():
            resized = F.interpolate(img, size=512, mode="nearest")
            if not normalized:
                resized = resized * 0.5 + 0.5
            resized = resized.sub(self.vgg_mean.detach()).div(self.vgg_std.detach())
            logits = self.bisenet(resized)[0]
            labels = logits.softmax(1).argmax(1)
        # Accumulate a 0/1 map over the requested classes (non-inplace: the
        # sum promotes the integer label tensor to float on the first add).
        selected = torch.zeros_like(labels)
        for cls in par:
            selected = selected + (labels == cls).float()
        selected = selected.unsqueeze(1)
        return F.interpolate(selected, size=original_size, mode="bilinear", align_corners=True)

    def save_backbone(self, path: str):
        """Serialize only the backbone weights to *path*."""
        torch.save(self.backbone.state_dict(), path)

    def load_backbone(self, path: str):
        """Load backbone weights from *path* (CPU-mapped)."""
        self.backbone.load_state_dict(torch.load(path, map_location='cpu'))
| if __name__ == "__main__": | |
| from third_party.bisenet.bisenet import BiSeNet | |
| bisenet = BiSeNet(19) | |
| bisenet.load_state_dict( | |
| torch.load( | |
| "/gavin/datasets/hanbang/79999_iter.pth", | |
| map_location="cpu", | |
| ) | |
| ) | |
| bisenet.eval() | |
| bisenet.requires_grad_(False) | |
| crop_param = (28, 56, 84, 112) | |
| import numpy as np | |
| img = np.random.randn(112, 112, 3) * 225 | |
| from PIL import Image | |
| img = Image.fromarray(img.astype(np.uint8)) | |
| img = img.crop(crop_param) | |
| from torchvision import transforms | |
| trans = transforms.ToTensor() | |
| img = trans(img).unsqueeze(0) | |
| img = img.repeat(3, 1, 1, 1) | |
| print(img.shape) | |
| net = MouthNet( | |
| bisenet=bisenet, | |
| feature_dim=64, | |
| crop_param=crop_param | |
| ) | |
| mouth_feat = net(img) | |
| print(mouth_feat.shape) | |
| import thop | |
| crop_size = (crop_param[3] - crop_param[1], crop_param[2] - crop_param[0]) # (H,W) | |
| fc_scale = int(math.ceil(crop_size[0] / 112 * 7) * math.ceil(crop_size[1] / 112 * 7)) | |
| backbone = iresnet100( | |
| pretrained=False, | |
| num_features=64, | |
| fp16=False, | |
| # fc_scale=fc_scale, | |
| ) | |
| flops, params = thop.profile(backbone, inputs=(torch.randn(1, 3, 112, 112),), verbose=False) | |
| print('#Params=%.2fM, GFLOPS=%.2f' % (params / 1e6, flops / 1e9)) | |