File size: 4,938 Bytes
f670afc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES.  All rights reserved.
#
# This work is made available under the Nvidia Source Code License-NC.
# To view a copy of this license, check out LICENSE.md
from torch import nn

from imaginaire.discriminators.multires_patch import \
    WeightSharedMultiResPatchDiscriminator
from imaginaire.discriminators.residual import ResDiscriminator


class Discriminator(nn.Module):
    r"""Pair of UNIT discriminators, one for domain A and one for domain B.

    Both discriminators are built from the same configuration. Depending on
    ``dis_cfg.patch_dis`` they are either multi-resolution patch
    discriminators (as in the original implementation) or global residual
    discriminators.

    Args:
        dis_cfg (obj): Discriminator definition part of the yaml config file.
            Its attributes are forwarded as keyword arguments to the
            selected discriminator class. ``patch_dis`` is read with a
            default of ``True`` when the attribute is missing.
        data_cfg (obj): Data definition part of the yaml config file. It is
            accepted but not read by this constructor.
    """

    def __init__(self, dis_cfg, data_cfg):
        super().__init__()
        # Patch discriminator: works better for scene images and when
        # pixel-wise correspondence should be preserved during translation.
        # Residual discriminator: works better if images have a single
        # centered object (e.g., animal faces, shoes).
        dis_class = (WeightSharedMultiResPatchDiscriminator
                     if getattr(dis_cfg, 'patch_dis', True)
                     else ResDiscriminator)
        # Two independent instances built from the same settings.
        self.discriminator_a = dis_class(**vars(dis_cfg))
        self.discriminator_b = dis_class(**vars(dis_cfg))

    def forward(self, data, net_G_output, gan_recon=False, real=True):
        r"""Run the domain discriminators on generated (and real) images.

        Images that end up in domain A (``images_ba``, ``images_a``,
        ``images_aa``) go through ``discriminator_a``; images that end up in
        domain B (``images_ab``, ``images_b``, ``images_bb``) go through
        ``discriminator_b``. Each discriminator call is expected to return
        three values, of which the third is discarded.

        Args:
            data (dict):
              - images_a  (tensor) : Images in domain A.
              - images_b  (tensor) : Images in domain B.

              Only read when ``real`` is ``True``.
            net_G_output (dict):
              - images_ab  (tensor) : Images translated from domain A to B by
                the generator.
              - images_ba  (tensor) : Images translated from domain B to A by
                the generator.
              - images_aa  (tensor) : Reconstructed images in domain A.
              - images_bb  (tensor) : Reconstructed images in domain B.

              The reconstructions are only read when ``gan_recon`` is
              ``True``.
            gan_recon (bool): If ``True``, also classifies reconstructed
                images.
            real (bool): If ``True``, also classifies real images. Otherwise
                it only classifies generated images to save computation
                during the generator update.
        Returns:
            (dict):
              - out_ab (tensor): Output of the discriminator for images
                translated from domain A to B by the generator.
              - out_ba (tensor): Output of the discriminator for images
                translated from domain B to A by the generator.
              - fea_ab (tensor): Intermediate features of the discriminator
                for images translated from domain A to B by the generator.
              - fea_ba (tensor): Intermediate features of the discriminator
                for images translated from domain B to A by the generator.

              Present only when ``real`` is ``True``:

              - out_a (tensor): Output of the discriminator for images
                in domain A.
              - out_b (tensor): Output of the discriminator for images
                in domain B.
              - fea_a (tensor): Intermediate features of the discriminator
                for images in domain A.
              - fea_b (tensor): Intermediate features of the discriminator
                for images in domain B.

              Present only when ``gan_recon`` is ``True``:

              - out_aa (tensor): Output of the discriminator for
                reconstructed images in domain A.
              - out_bb (tensor): Output of the discriminator for
                reconstructed images in domain B.
              - fea_aa (tensor): Intermediate features of the discriminator
                for reconstructed images in domain A.
              - fea_bb (tensor): Intermediate features of the discriminator
                for reconstructed images in domain B.
        """
        dis_a = self.discriminator_a
        dis_b = self.discriminator_b

        # Translated images are always classified.
        pred_ab, feat_ab, _ = dis_b(net_G_output['images_ab'])
        pred_ba, feat_ba, _ = dis_a(net_G_output['images_ba'])
        output = {
            'out_ba': pred_ba,
            'out_ab': pred_ab,
            'fea_ba': feat_ba,
            'fea_ab': feat_ab,
        }

        if real:
            pred_a, feat_a, _ = dis_a(data['images_a'])
            pred_b, feat_b, _ = dis_b(data['images_b'])
            output['out_a'] = pred_a
            output['out_b'] = pred_b
            output['fea_a'] = feat_a
            output['fea_b'] = feat_b

        if gan_recon:
            pred_aa, feat_aa, _ = dis_a(net_G_output['images_aa'])
            pred_bb, feat_bb, _ = dis_b(net_G_output['images_bb'])
            output['out_aa'] = pred_aa
            output['out_bb'] = pred_bb
            output['fea_aa'] = feat_aa
            output['fea_bb'] = feat_bb

        return output