Spaces:
Running
Running
File size: 873 Bytes
5ab5cab |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 |
import torch.nn as nn
import torch.nn.functional as F
from auto_encoder.models.decoder import Decoder
from auto_encoder.models.encoder import Encoder
import yaml
class AutoEncoder(nn.Module):
    """Encoder/decoder pair whose sub-modules are configured from a YAML file.

    The YAML file must contain top-level ``encoder`` and ``decoder`` mappings.
    Each mapping is passed as keyword arguments to ``Encoder`` and ``Decoder``
    respectively.
    """

    def __init__(self, config_path: str):
        """Load the YAML config at *config_path* and build both sub-modules.

        Args:
            config_path: Path to a YAML file with ``encoder`` and ``decoder``
                sections.

        Raises:
            OSError: If the config file cannot be opened.
            KeyError: If the ``encoder`` or ``decoder`` section is missing.
        """
        super().__init__()
        # Decode explicitly as UTF-8 so parsing does not depend on the
        # platform's default locale encoding.
        with open(config_path, "r", encoding="utf-8") as fh:
            config = yaml.safe_load(fh)
        # Attribute assignment registers sub-modules exactly like
        # add_module(), and in the same order (encoder first).
        self.encoder = Encoder(**config["encoder"])
        self.decoder = Decoder(**config["decoder"])

    def encode(self, x):
        """Return the encoder's output for *x*."""
        return self.encoder(x)

    def decode(self, z):
        """Return the decoder's output for the latent *z*."""
        return self.decoder(z)

    def reconstruct(self, x):
        """Encode *x* and decode the result."""
        return self.decode(self.encode(x))

    def loss(self, x):
        """Return the mean squared error between *x* and its reconstruction."""
        x_hat = self(x)
        # PyTorch's signature is mse_loss(input, target). MSE is symmetric,
        # so the value and gradients match the previous argument order.
        return F.mse_loss(x_hat, x)

    def forward(self, x):
        """Reconstruct *x*; ``model(x)`` is the same as ``model.reconstruct(x)``."""
        return self.reconstruct(x)