File size: 873 Bytes
5ab5cab
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
import torch.nn as nn
import torch.nn.functional as F
from auto_encoder.models.decoder import Decoder
from auto_encoder.models.encoder import Encoder
import yaml

class AutoEncoder(nn.Module):
    """Encoder/decoder pair whose sub-modules are built from a YAML config.

    The config file must contain top-level ``encoder`` and ``decoder``
    sections; each section is unpacked as keyword arguments into
    ``Encoder`` and ``Decoder`` respectively.
    """

    def __init__(self, config_path: str):
        """Build the encoder and decoder from the YAML file at *config_path*.

        Args:
            config_path: Path to a YAML file with top-level ``encoder`` and
                ``decoder`` sections (any path accepted by ``open`` works).

        Raises:
            OSError: If the file cannot be opened.
            yaml.YAMLError: If the file is not valid YAML.
            TypeError: If the YAML document is not a mapping (e.g. empty file).
            KeyError: If the ``encoder`` or ``decoder`` section is missing.
        """
        super().__init__()
        # Explicit encoding so parsing does not depend on the platform locale.
        with open(config_path, "r", encoding="utf-8") as file:
            # safe_load only builds plain YAML types (no arbitrary objects).
            config = yaml.safe_load(file)
        if not isinstance(config, dict):
            # An empty file loads as None; report that clearly instead of
            # failing later with "'NoneType' object is not subscriptable".
            raise TypeError(
                f"config file {config_path!r} must contain a YAML mapping, "
                f"got {type(config).__name__}"
            )
        for section in ("encoder", "decoder"):
            if section not in config:
                raise KeyError(
                    f"config file {config_path!r} has no {section!r} section"
                )
        # Assigning an nn.Module attribute registers it as a sub-module,
        # exactly like add_module(name, module).
        self.encoder = Encoder(**config["encoder"])
        self.decoder = Decoder(**config["decoder"])

    def encode(self, x):
        """Return the encoder's output (latent representation) for *x*."""
        return self.encoder(x)

    def decode(self, z):
        """Return the decoder's output for the latent input *z*."""
        return self.decoder(z)

    def reconstruct(self, x):
        """Encode *x* and decode the result back into input space."""
        return self.decode(self.encode(x))

    def loss(self, x):
        """Return the mean-squared reconstruction error of *x*.

        ``F.mse_loss`` uses its default ``reduction="mean"``. Arguments are
        passed in the conventional (prediction, target) order.
        """
        x_hat = self(x)
        return F.mse_loss(x_hat, x)

    def forward(self, x):
        """Reconstruct *x* (``encode`` followed by ``decode``)."""
        return self.reconstruct(x)