|
import torch |
|
import numpy as np |
|
import matplotlib.pyplot as plt |
|
import gradio as gr |
|
|
|
|
|
class VAE(torch.nn.Module):
    """Fully connected variational autoencoder.

    The encoder flattens its input, applies one hidden ReLU layer and
    projects to the mean (``mu``) and log-variance (``logvar``) of a
    diagonal Gaussian over the latent space. The decoder maps a latent
    vector back to ``input_dim`` values squashed into [0, 1] by a sigmoid.

    The default sizes (784 -> 400 -> 20) are the ones this class originally
    hard-coded, so ``VAE()`` produces exactly the same parameter names and
    shapes as before and existing state dicts still load.

    Args:
        input_dim: Number of input features after flattening (28*28 by default).
        hidden_dim: Width of the single hidden layer in encoder and decoder.
        latent_dim: Dimensionality of the latent space.
    """

    def __init__(self, input_dim: int = 28 * 28, hidden_dim: int = 400, latent_dim: int = 20):
        super().__init__()
        # Plain int attribute (not a parameter), so it does not appear in state_dict.
        self.latent_dim = latent_dim

        self.encoder = torch.nn.Sequential(
            # nn.Flatten keeps dim 0 (batch) and flattens the rest, so
            # (N, 28, 28), (N, 1, 28, 28) and (N, 784) inputs are all accepted.
            torch.nn.Flatten(),
            torch.nn.Linear(input_dim, hidden_dim),
            torch.nn.ReLU(),
        )
        self.mu = torch.nn.Linear(hidden_dim, latent_dim)
        self.logvar = torch.nn.Linear(hidden_dim, latent_dim)
        self.decoder = torch.nn.Sequential(
            torch.nn.Linear(latent_dim, hidden_dim),
            torch.nn.ReLU(),
            torch.nn.Linear(hidden_dim, input_dim),
            torch.nn.Sigmoid(),
        )

    def encode(self, x):
        """Return ``(mu, logvar)``, the posterior parameters for input ``x``."""
        h = self.encoder(x)
        return self.mu(h), self.logvar(h)

    def reparameterize(self, mu, logvar):
        """Draw ``z = mu + eps * std`` with ``eps ~ N(0, I)``.

        ``std`` is recovered as ``exp(0.5 * logvar)``. Noise is sampled on
        every call, including in eval mode.
        """
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def forward(self, x):
        """Encode ``x``, sample a latent vector and return the decoded output.

        Only the reconstruction is returned; call ``encode`` to obtain the
        ``mu``/``logvar`` needed for a KL term.
        """
        mu, logvar = self.encode(x)
        z = self.reparameterize(mu, logvar)
        return self.decoder(z)
|
|
|
|
|
# Instantiate the network and restore its parameters from the local
# checkpoint, mapping every stored tensor onto the CPU.
# NOTE(review): torch.load unpickles the file. Only a state_dict is needed, so
# weights_only=True (torch >= 1.13) would be safer if the checkpoint could
# ever come from an untrusted source.
model = VAE()
checkpoint_path = "vae_mnist.pth"
model.load_state_dict(
    torch.load(checkpoint_path, map_location="cpu")
)
del checkpoint_path
# train(False) is what eval() does. Note that VAE.reparameterize still samples
# noise in this mode; the flag only affects layers that consult self.training.
model.train(False)
|
|
|
|
|
def generate_images(digit):
    """Sample five digit images from the VAE's standard-normal prior.

    Latent vectors are drawn from N(0, I) and decoded by the module-level
    ``model``. The decoder takes only the latent vector, so ``digit`` cannot
    influence the output. It is accepted only because the Gradio dropdown
    passes it.

    Args:
        digit: Value from the digit dropdown; currently unused.

    Returns:
        A list of five 28x28 ``uint8`` grayscale arrays, one per Gradio
        image output.
    """
    num_images = 5  # must match the number of gr.Image outputs in the interface
    # Read the latent size from the model instead of duplicating the constant 20.
    latent_dim = model.mu.out_features

    # Inference only: no_grad avoids building an autograd graph, and a single
    # batched decoder call replaces five one-row calls.
    with torch.no_grad():
        z = torch.randn(num_images, latent_dim)
        pixels = model.decoder(z).cpu().numpy()

    # The decoder ends in a sigmoid (values in [0, 1]). Round instead of
    # truncating when scaling to 8-bit; clip guards against float edge cases.
    pixels = np.clip(np.rint(pixels * 255), 0, 255).astype(np.uint8)
    return [row.reshape(28, 28) for row in pixels]
|
|
|
|
|
# Gradio UI: one digit dropdown in, five grayscale images out (one per array
# returned by generate_images).
iface = gr.Interface(
    fn=generate_images,
    inputs=gr.Dropdown(choices=[str(i) for i in range(10)], label="Choose a digit (ignored for now)"),
    # No `shape=` here. In Gradio 3.x it only resized images passed *into* the
    # function, so it did nothing on outputs, and Gradio 4+ no longer accepts it.
    outputs=[gr.Image(image_mode='L') for _ in range(5)],
    title="Handwritten Digit Generator",
    description="Select a digit (0–9) and generate 5 handwritten-style digits using a VAE trained on MNIST."
)

# Start the web server only when run as a script, not when this module is imported.
if __name__ == "__main__":
    iface.launch()
|
|