# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES.  All rights reserved.
#
# This work is made available under the Nvidia Source Code License-NC.
# To view a copy of this license, check out LICENSE.md
import functools

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from imaginaire.layers import Conv2dBlock


class FPSEDiscriminator(nn.Module):
    r"""# Feature-Pyramid Semantics Embedding Discriminator. This is a copy
    of the discriminator in https://arxiv.org/pdf/1910.06809.pdf
    """

    def __init__(self,
                 num_input_channels,
                 num_labels,
                 num_filters,
                 kernel_size,
                 weight_norm_type,
                 activation_norm_type):
        super().__init__()
        padding = int(np.ceil((kernel_size - 1.0) / 2))
        nonlinearity = 'leakyrelu'
        stride1_conv2d_block = \
            functools.partial(Conv2dBlock,
                              kernel_size=kernel_size,
                              stride=1,
                              padding=padding,
                              weight_norm_type=weight_norm_type,
                              activation_norm_type=activation_norm_type,
                              nonlinearity=nonlinearity,
                              order='CNA')
        down_conv2d_block = \
            functools.partial(Conv2dBlock,
                              kernel_size=kernel_size,
                              stride=2,
                              padding=padding,
                              weight_norm_type=weight_norm_type,
                              activation_norm_type=activation_norm_type,
                              nonlinearity=nonlinearity,
                              order='CNA')
        latent_conv2d_block = \
            functools.partial(Conv2dBlock,
                              kernel_size=1,
                              stride=1,
                              weight_norm_type=weight_norm_type,
                              activation_norm_type=activation_norm_type,
                              nonlinearity=nonlinearity,
                              order='CNA')
        # bottom-up pathway
        self.enc1 = down_conv2d_block(num_input_channels, num_filters)
        self.enc2 = down_conv2d_block(1 * num_filters, 2 * num_filters)
        self.enc3 = down_conv2d_block(2 * num_filters, 4 * num_filters)
        self.enc4 = down_conv2d_block(4 * num_filters, 8 * num_filters)
        self.enc5 = down_conv2d_block(8 * num_filters, 8 * num_filters)

        # top-down pathway
        self.lat2 = latent_conv2d_block(2 * num_filters, 4 * num_filters)
        self.lat3 = latent_conv2d_block(4 * num_filters, 4 * num_filters)
        self.lat4 = latent_conv2d_block(8 * num_filters, 4 * num_filters)
        self.lat5 = latent_conv2d_block(8 * num_filters, 4 * num_filters)

        # upsampling
        self.upsample2x = nn.Upsample(scale_factor=2, mode='bilinear',
                                      align_corners=False)

        # final layers
        self.final2 = stride1_conv2d_block(4 * num_filters, 2 * num_filters)
        self.final3 = stride1_conv2d_block(4 * num_filters, 2 * num_filters)
        self.final4 = stride1_conv2d_block(4 * num_filters, 2 * num_filters)

        # true/false prediction and semantic alignment prediction
        self.output = Conv2dBlock(num_filters * 2, 1, kernel_size=1)
        self.seg = Conv2dBlock(num_filters * 2, num_filters * 2, kernel_size=1)
        self.embedding = Conv2dBlock(num_labels, num_filters * 2, kernel_size=1)

    def forward(self, images, segmaps):
        r"""Compute multi-scale patch-based discriminator scores.

        Args:
            images (tensor): Image tensor of shape (N, C, H, W), where H and
                W should be divisible by 32.
            segmaps (tensor): One-hot segmentation map tensor of shape
                (N, num_labels, H, W).

        Returns:
            (tuple of tensors): Patch-wise scores at 1/4, 1/8 and 1/16 of
                the input resolution.
        """
        # bottom-up pathway
        feat11 = self.enc1(images)
        feat12 = self.enc2(feat11)
        feat13 = self.enc3(feat12)
        feat14 = self.enc4(feat13)
        feat15 = self.enc5(feat14)
        # top-down pathway and lateral connections
        feat25 = self.lat5(feat15)
        feat24 = self.upsample2x(feat25) + self.lat4(feat14)
        feat23 = self.upsample2x(feat24) + self.lat3(feat13)
        feat22 = self.upsample2x(feat23) + self.lat2(feat12)
        # final prediction layers
        feat32 = self.final2(feat22)
        feat33 = self.final3(feat23)
        feat34 = self.final4(feat24)
        # patch-based true/false prediction
        pred2 = self.output(feat32)
        pred3 = self.output(feat33)
        pred4 = self.output(feat34)
        # features for the semantic alignment score
        seg2 = self.seg(feat32)
        seg3 = self.seg(feat33)
        seg4 = self.seg(feat34)

        # segmentation map embedding, average-pooled to match the 1/4, 1/8
        # and 1/16 resolution feature maps
        segembs = self.embedding(segmaps)
        segembs = F.avg_pool2d(segembs, kernel_size=2, stride=2)
        segembs2 = F.avg_pool2d(segembs, kernel_size=2, stride=2)
        segembs3 = F.avg_pool2d(segembs2, kernel_size=2, stride=2)
        segembs4 = F.avg_pool2d(segembs3, kernel_size=2, stride=2)

        # semantics embedding discriminator score: projection-style inner
        # product between the segmentation embedding and the predicted
        # features at each spatial location
        pred2 += torch.mul(segembs2, seg2).sum(dim=1, keepdim=True)
        pred3 += torch.mul(segembs3, seg3).sum(dim=1, keepdim=True)
        pred4 += torch.mul(segembs4, seg4).sum(dim=1, keepdim=True)

        # return patch scores from multiple resolutions

        return pred2, pred3, pred4
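

# ----------------------------------------------------------------------------
# Usage sketch (illustrative only, requires the imaginaire package for
# Conv2dBlock). The settings below are assumptions chosen for demonstration,
# not values taken from any imaginaire config: 3-channel RGB images, 35
# semantic labels (Cityscapes-like), spectral weight normalization and no
# activation normalization. Input height and width must be multiples of 32
# so that the five stride-2 encoders and the 2x bilinear upsampling line up.
if __name__ == '__main__':
    net = FPSEDiscriminator(num_input_channels=3,
                            num_labels=35,
                            num_filters=64,
                            kernel_size=3,
                            weight_norm_type='spectral',
                            activation_norm_type='none')
    images = torch.randn(2, 3, 256, 256)
    # Build one-hot segmentation maps with the same spatial size as images.
    labels = torch.randint(0, 35, (2, 256, 256))
    segmaps = F.one_hot(labels, num_classes=35).permute(0, 3, 1, 2).float()
    pred2, pred3, pred4 = net(images, segmaps)
    # Patch scores at 1/4, 1/8 and 1/16 of the input resolution:
    # (2, 1, 64, 64), (2, 1, 32, 32) and (2, 1, 16, 16).
    print(pred2.shape, pred3.shape, pred4.shape)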