import torch
import torchvision
import torch.nn as nn

from .conv_autoencoder import ConvEncoder, DeconvDecoder, INRDecoder
from .ops import ScaleLayer


class IHModelWithBackbone(nn.Module):
    def __init__(
        self,
        model, backbone,
        downsize_backbone_input=False,
        mask_fusion='sum',
        backbone_conv1_channels=64, opt=None
    ):
        super(IHModelWithBackbone, self).__init__()
        self.downsize_backbone_input = downsize_backbone_input
        self.mask_fusion = mask_fusion

        self.backbone = backbone
        self.model = model
        self.opt = opt

        # Project the 1-channel foreground mask to the backbone's conv1 width and scale it
        # down (ScaleLayer, init 0.1); the result is passed to the backbone as mask features.
        self.mask_conv = nn.Sequential(
            nn.Conv2d(1, backbone_conv1_channels, kernel_size=3, stride=2, padding=1, bias=True),
            ScaleLayer(init_value=0.1, lr_mult=1)
        )

    def forward(self, image, mask, coord=None, start_proportion=None):
        # In high-resolution INR training/splitting mode, `image` and `mask` arrive as pairs:
        # element 0 is resized to base_size for the backbone, while the full-resolution
        # element is consumed later by the wrapped model.
        if self.opt.INRDecode and self.opt.hr_train and (self.training or hasattr(self.opt, 'split_num') or hasattr(self.opt, 'split_resolution')):
            backbone_image = torchvision.transforms.Resize([self.opt.base_size, self.opt.base_size])(image[0])
            backbone_mask = torch.cat(
                (torchvision.transforms.Resize([self.opt.base_size, self.opt.base_size])(mask[0]),
                 1.0 - torchvision.transforms.Resize([self.opt.base_size, self.opt.base_size])(mask[0])), dim=1)
        else:
            backbone_image = torchvision.transforms.Resize([self.opt.base_size, self.opt.base_size])(image)
            backbone_mask = torch.cat((torchvision.transforms.Resize([self.opt.base_size, self.opt.base_size])(mask),
                                       1.0 - torchvision.transforms.Resize([self.opt.base_size, self.opt.base_size])(mask)), dim=1)

        # `backbone_mask` holds (foreground, background) channels; only the foreground
        # channel is projected by mask_conv before being handed to the backbone.
        backbone_mask_features = self.mask_conv(backbone_mask[:, :1])
        backbone_features = self.backbone(backbone_image, backbone_mask, backbone_mask_features)

        output = self.model(image, mask, backbone_features, coord=coord, start_proportion=start_proportion)
        return output
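
# Composition sketch (hedged: `build_backbone`, `opt`, and the argument values below are
# illustrative placeholders, not this repo's actual training configuration):
#
#     backbone = build_backbone()                           # any module with forward(image, mask, mask_features)
#     base_net = DeepImageHarmonization(depth=4, opt=opt)   # defined below
#     model = IHModelWithBackbone(base_net, backbone, opt=opt)
#     pred = model(images, masks)                           # images: (B, 3, H, W), masks: (B, 1, H, W)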


class DeepImageHarmonization(nn.Module):
    def __init__(
        self,
        depth,
        norm_layer=nn.BatchNorm2d, batchnorm_from=0,
        attend_from=-1,
        image_fusion=False,
        ch=64, max_channels=512,
        backbone_from=-1, backbone_channels=None, backbone_mode='', opt=None
    ):
        super(DeepImageHarmonization, self).__init__()
        self.depth = depth
        self.encoder = ConvEncoder(
            depth, ch,
            norm_layer, batchnorm_from, max_channels,
            backbone_from, backbone_channels, backbone_mode, INRDecode=opt.INRDecode
        )
        self.opt = opt

        if opt.INRDecode:
            # See Table 2 in the paper to test different INR decoder structures.
            self.decoder = INRDecoder(depth, self.encoder.blocks_channels, norm_layer, opt, backbone_from)
        else:
            # Baseline decoder: https://github.com/SamsungLabs/image_harmonization
            self.decoder = DeconvDecoder(depth, self.encoder.blocks_channels, norm_layer, attend_from, image_fusion)

    def forward(self, image, mask, backbone_features=None, coord=None, start_proportion=None):
        # Encoder input: the (resized) composite image concatenated with its mask. In
        # high-resolution INR mode the inputs are pairs, with element 1 holding the
        # full-resolution data that the INR decoder reconstructs.
        if self.opt.INRDecode and self.opt.hr_train and (self.training or hasattr(self.opt, 'split_num') or hasattr(self.opt, 'split_resolution')):
            x = torch.cat((torchvision.transforms.Resize([self.opt.base_size, self.opt.base_size])(image[0]),
                           torchvision.transforms.Resize([self.opt.base_size, self.opt.base_size])(mask[0])), dim=1)
        else:
            x = torch.cat((torchvision.transforms.Resize([self.opt.base_size, self.opt.base_size])(image),
                           torchvision.transforms.Resize([self.opt.base_size, self.opt.base_size])(mask)), dim=1)

        intermediates = self.encoder(x, backbone_features)

        if self.opt.INRDecode and self.opt.hr_train and (self.training or hasattr(self.opt, 'split_num') or hasattr(self.opt, 'split_resolution')):
            output = self.decoder(intermediates, image[1], mask[1], coord_samples=coord, start_proportion=start_proportion)
        else:
            output = self.decoder(intermediates, image, mask)

        return output
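
# Minimal smoke-test sketch, assuming the baseline (non-INR) path; the `opt` fields below are
# the ones read in this file, while depth and tensor sizes are illustrative guesses:
#
#     from argparse import Namespace
#     opt = Namespace(INRDecode=False, hr_train=False, base_size=256)
#     net = DeepImageHarmonization(depth=4, opt=opt)
#     comp = torch.rand(1, 3, 256, 256)    # composite image
#     mask = torch.rand(1, 1, 256, 256)    # foreground mask
#     out = net(comp, mask)                # harmonized output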