File size: 1,837 Bytes
4ad8a34
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
import torch
import numpy as np
import matplotlib.pyplot as plt
import gradio as gr

# Define the same VAE architecture used during training
class VAE(torch.nn.Module):
    """Fully connected variational autoencoder.

    With the default sizes the layers are 784 -> 400 -> 20 -> 400 -> 784 and
    the submodule names (``encoder``, ``mu``, ``logvar``, ``decoder``) are
    unchanged, so state dicts saved from the default architecture still load.
    """

    def __init__(
        self,
        input_dim: int = 28 * 28,
        hidden_dim: int = 400,
        latent_dim: int = 20,
    ) -> None:
        """Build the encoder, the two latent heads and the decoder.

        Args:
            input_dim: Number of features per sample after flattening.
            hidden_dim: Width of the hidden layer in encoder and decoder.
            latent_dim: Size of the latent vector.
        """
        super().__init__()
        self.encoder = torch.nn.Sequential(
            torch.nn.Flatten(),
            torch.nn.Linear(input_dim, hidden_dim),
            torch.nn.ReLU(),
        )
        # Two heads on the shared hidden features: mean and log-variance.
        self.mu = torch.nn.Linear(hidden_dim, latent_dim)
        self.logvar = torch.nn.Linear(hidden_dim, latent_dim)
        self.decoder = torch.nn.Sequential(
            torch.nn.Linear(latent_dim, hidden_dim),
            torch.nn.ReLU(),
            torch.nn.Linear(hidden_dim, input_dim),
            torch.nn.Sigmoid(),  # keeps outputs in [0, 1]
        )

    def reparameterize(self, mu: torch.Tensor, logvar: torch.Tensor) -> torch.Tensor:
        """Return ``mu + eps * std`` with ``eps`` drawn from a standard normal.

        Args:
            mu: Mean tensor.
            logvar: Log-variance tensor; ``std`` is ``exp(0.5 * logvar)``.

        Returns:
            A random sample with the same shape as ``logvar``.
        """
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Encode ``x``, sample a latent vector and decode it.

        The input is flattened per sample, so every sample must contain
        ``input_dim`` elements. A fresh latent is sampled on every call
        (also in eval mode), so the output is not deterministic.

        Returns:
            The decoder output of shape ``(batch, input_dim)``.
        """
        h = self.encoder(x)
        mu, logvar = self.mu(h), self.logvar(h)
        z = self.reparameterize(mu, logvar)
        return self.decoder(z)

# Build the network, restore the trained weights on CPU, switch to eval mode.
# NOTE(review): torch.load unpickles the checkpoint file - only load files
# from a trusted source.
model = VAE()
trained_weights = torch.load("vae_mnist.pth", map_location='cpu')
model.load_state_dict(trained_weights)
model.eval()

# Generation function for Gradio
def generate_images(digit):
    """Sample five random images from the VAE decoder.

    Args:
        digit: Value selected in the UI. It is ignored: latents are drawn
            at random, so the images are not conditioned on the chosen digit.

    Returns:
        A list of five ``uint8`` arrays of shape ``(28, 28)``, one per
        image output of the interface.
    """
    num_images = 5  # must match the number of gr.Image outputs
    # Read the latent size from the decoder's first layer instead of
    # repeating the architecture constant here.
    latent_dim = model.decoder[0].in_features
    # Inference only: skip autograd bookkeeping and decode all samples
    # in a single batched call.
    with torch.no_grad():
        z = torch.randn(num_images, latent_dim)
        batch = model.decoder(z).reshape(num_images, 28, 28).numpy()
    # Decoder outputs are scaled by 255 and cast to 8-bit pixel values.
    return list((batch * 255).astype(np.uint8))

# Gradio interface
# One dropdown input and five grayscale image outputs (generate_images
# returns five arrays).
# NOTE: gr.Image is created without `shape=`. That argument only affected
# input preprocessing in Gradio 3.x and was removed in Gradio 4, where
# passing it raises a TypeError.
iface = gr.Interface(
    fn=generate_images,
    inputs=gr.Dropdown(choices=[str(i) for i in range(10)], label="Choose a digit (ignored for now)"),
    outputs=[gr.Image(image_mode='L') for _ in range(5)],
    title="Handwritten Digit Generator",
    description="Select a digit (0–9) and generate 5 handwritten-style digits using a VAE trained on MNIST."
)

iface.launch()