File size: 958 Bytes
3e165b2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
from torch import nn

from visualizr.networks.encoder import Encoder
from visualizr.networks.styledecoder import Synthesis


class Generator(nn.Module):
    """Image generator built from an ``Encoder`` and a ``Synthesis`` decoder.

    The encoder maps a source image and a driving image to a triple
    ``(wa, alpha, feats)``. The decoder renders an image from that triple.

    Args:
        size: Passed unchanged to both ``Encoder`` and ``Synthesis``
            (presumably the image resolution -- confirm in those modules).
        style_dim: Passed to both sub-networks. Defaults to 512.
        motion_dim: Passed to both sub-networks. Defaults to 20.
        channel_multiplier: Passed to ``Synthesis``. Defaults to 1.
        blur_kernel: Passed to ``Synthesis``. ``None`` (the default) selects
            a fresh ``[1, 3, 3, 1]`` list for each instance.
    """

    def __init__(
        self,
        size,
        style_dim=512,
        motion_dim=20,
        channel_multiplier=1,
        blur_kernel=None,
    ):
        super().__init__()

        # A list literal used as a default is created once and shared by
        # every call. Any in-place change by a consumer would then leak into
        # later instances, so build a fresh list per call instead.
        if blur_kernel is None:
            blur_kernel = [1, 3, 3, 1]

        self.enc = Encoder(size, style_dim, motion_dim)
        self.dec = Synthesis(
            size, style_dim, motion_dim, blur_kernel, channel_multiplier
        )

    def get_direction(self):
        """Return ``self.dec.direction(None)``.

        NOTE(review): presumably this is the decoder's motion-direction basis;
        confirm against ``Synthesis.direction``.
        """
        return self.dec.direction(None)

    def synthesis(self, wa, alpha, feat):
        """Decode ``(wa, alpha, feat)`` into an image using the decoder."""
        return self.dec(wa, alpha, feat)

    def forward(self, img_source, img_drive, h_start=None):
        """Encode a source/driving image pair and decode the result.

        Args:
            img_source: Source image, passed to the encoder.
            img_drive: Driving image, passed to the encoder.
            h_start: Optional value forwarded unchanged to the encoder.

        Returns:
            The decoder's output for the encoder's ``(wa, alpha, feats)``.
        """
        wa, alpha, feats = self.enc(img_source, img_drive, h_start)
        return self.dec(wa, alpha, feats)