import torch
import torch.nn as nn
import torch.nn.functional as F
import yaml

from auto_encoder.models.encoder import Encoder
from auto_encoder.models.decoder import Decoder
from auto_encoder.components.distributions import DiagonalGaussianDistribution

class VariationalAutoEncoder(nn.Module):
    """Convolutional VAE assembled from a YAML config with encoder, decoder, and vae sections."""

    def __init__(self, config_path):
        super().__init__()
        with open(config_path, "r") as file:
            config = yaml.safe_load(file)

        # Encoder/decoder sub-networks built from their respective config sections.
        self.encoder = Encoder(**config["encoder"])
        self.decoder = Decoder(**config["decoder"])
        self.embed_dim = config["vae"]["embed_dim"]
        self.kld_weight = float(config["vae"]["kld_weight"])

        # 1x1 convolutions: quant_conv maps encoder features to the posterior moments
        # (mean and log-variance, hence 2 * embed_dim channels); post_quant_conv maps
        # a sampled latent back to the channel count the decoder expects.
        self.quant_conv = nn.Conv2d(self.decoder.z_channels, 2 * self.embed_dim, 1)
        self.post_quant_conv = nn.Conv2d(self.embed_dim, self.decoder.z_channels, 1)

    def encode(self, x):
        # Encode the input and wrap the resulting moments in a diagonal Gaussian posterior.
        h = self.encoder(x)
        moments = self.quant_conv(h)
        posterior = DiagonalGaussianDistribution(moments)
        return posterior

    def decode(self, z):
        # Map the latent back to decoder channels and reconstruct the input.
        z = self.post_quant_conv(z)
        dec = self.decoder(z)
        return dec
    
    def loss(self, x):
        # Reconstruction term (MSE) plus the KL divergence weighted by kld_weight.
        x_hat, posterior = self(x)
        return F.mse_loss(x_hat, x) + self.kld_weight * posterior.kl().mean()

    def forward(self, input, sample_posterior=True):
        # Encode, draw a latent (a reparameterized sample or the posterior mode),
        # then decode; return both the reconstruction and the posterior.
        posterior = self.encode(input)
        if sample_posterior:
            z = posterior.sample()
        else:
            z = posterior.mode()
        dec = self.decode(z)
        return dec, posterior
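

# Minimal usage sketch (illustrative only): builds the model from a YAML config and
# runs one reconstruction/loss step. The config path "configs/vae.yaml" and the
# 3x256x256 input shape are assumptions, not part of this repository.
if __name__ == "__main__":
    vae = VariationalAutoEncoder("configs/vae.yaml")  # hypothetical config path
    x = torch.randn(4, 3, 256, 256)                   # dummy batch of RGB images
    x_hat, posterior = vae(x)                         # reconstruction + posterior
    total_loss = vae.loss(x)                          # MSE + weighted KL
    print(x_hat.shape, total_loss.item())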