"""Variational Autoencoder (VAE) model definition."""
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import List
class VAE(nn.Module):
    """Variational Autoencoder with symmetric MLP encoder and decoder.

    The encoder maps ``input_dim -> hidden_dims[0] -> ... -> hidden_dims[-1]``
    with ReLU activations, followed by two linear heads producing the latent
    mean and log-variance. The decoder mirrors the encoder back to
    ``input_dim`` (no output activation — callers apply e.g. sigmoid/BCE
    themselves).

    Args:
        input_dim: Flattened input feature size.
        latent_dim: Dimensionality of the latent code ``z``.
        hidden_dims: Non-empty list of hidden layer widths for the encoder;
            the decoder uses the reversed sequence. The list is NOT modified.
    """

    def __init__(self, input_dim: int, latent_dim: int, hidden_dims: List[int]):
        super().__init__()

        # --- Encoder: Linear + ReLU stack ---
        encoder_layers = []
        in_features = input_dim
        for h_dim in hidden_dims:
            encoder_layers.append(nn.Linear(in_features, h_dim))
            encoder_layers.append(nn.ReLU())
            in_features = h_dim
        self.encoder = nn.Sequential(*encoder_layers)

        # --- Latent heads (mean and log-variance) ---
        self.fc_mu = nn.Linear(hidden_dims[-1], latent_dim)
        self.fc_var = nn.Linear(hidden_dims[-1], latent_dim)

        # --- Decoder: mirror of the encoder ---
        # Iterate a reversed *view* instead of calling hidden_dims.reverse(),
        # which mutated the caller's list in place (bug: reusing the same
        # list for a second model produced a different architecture).
        decoder_layers = []
        in_features = latent_dim
        for h_dim in reversed(hidden_dims):
            decoder_layers.append(nn.Linear(in_features, h_dim))
            decoder_layers.append(nn.ReLU())
            in_features = h_dim
        # Final projection back to the input space; in_features here equals
        # hidden_dims[0], matching the original (post-reverse) wiring.
        decoder_layers.append(nn.Linear(in_features, input_dim))
        self.decoder = nn.Sequential(*decoder_layers)

    def encode(self, x):
        """Encode ``x`` and return ``(mu, log_var)`` of the latent posterior."""
        h = self.encoder(x)
        return self.fc_mu(h), self.fc_var(h)

    def decode(self, z):
        """Decode latent code ``z`` back to input space (no final activation)."""
        return self.decoder(z)

    def reparameterize(self, mu, log_var):
        """Sample z = mu + sigma * eps with eps ~ N(0, I) (reparameterization trick)."""
        std = torch.exp(0.5 * log_var)
        eps = torch.randn_like(std)
        return mu + eps * std

    def forward(self, x):
        """Run the full VAE; returns ``(reconstruction, mu, log_var)``."""
        mu, log_var = self.encode(x)
        z = self.reparameterize(mu, log_var)
        return self.decode(z), mu, log_var