debojit01 commited on
Commit
4ad8a34
·
verified ·
1 Parent(s): b11925a

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +59 -0
app.py ADDED
@@ -0,0 +1,59 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import numpy as np
3
+ import matplotlib.pyplot as plt
4
+ import gradio as gr
5
+
6
+ # Define the same VAE architecture used during training
7
+ class VAE(torch.nn.Module):
8
+ def __init__(self):
9
+ super().__init__()
10
+ self.encoder = torch.nn.Sequential(
11
+ torch.nn.Flatten(),
12
+ torch.nn.Linear(28*28, 400),
13
+ torch.nn.ReLU(),
14
+ )
15
+ self.mu = torch.nn.Linear(400, 20)
16
+ self.logvar = torch.nn.Linear(400, 20)
17
+ self.decoder = torch.nn.Sequential(
18
+ torch.nn.Linear(20, 400),
19
+ torch.nn.ReLU(),
20
+ torch.nn.Linear(400, 28*28),
21
+ torch.nn.Sigmoid()
22
+ )
23
+
24
+ def reparameterize(self, mu, logvar):
25
+ std = torch.exp(0.5 * logvar)
26
+ eps = torch.randn_like(std)
27
+ return mu + eps * std
28
+
29
+ def forward(self, x):
30
+ h = self.encoder(x)
31
+ mu, logvar = self.mu(h), self.logvar(h)
32
+ z = self.reparameterize(mu, logvar)
33
+ return self.decoder(z)
34
+
35
+ # Load model
36
+ model = VAE()
37
+ model.load_state_dict(torch.load("vae_mnist.pth", map_location='cpu'))
38
+ model.eval()
39
+
40
+ # Generation function for Gradio
41
+ def generate_images(digit):
42
+ # For VAE, we ignore the digit and generate random samples
43
+ images = []
44
+ for _ in range(5):
45
+ z = torch.randn(1, 20)
46
+ img = model.decoder(z).detach().numpy().reshape(28, 28)
47
+ images.append((img * 255).astype(np.uint8))
48
+ return images
49
+
50
+ # Gradio interface
51
+ iface = gr.Interface(
52
+ fn=generate_images,
53
+ inputs=gr.Dropdown(choices=[str(i) for i in range(10)], label="Choose a digit (ignored for now)"),
54
+ outputs=[gr.Image(shape=(28,28), image_mode='L') for _ in range(5)],
55
+ title="Handwritten Digit Generator",
56
+ description="Select a digit (0–9) and generate 5 handwritten-style digits using a VAE trained on MNIST."
57
+ )
58
+
59
+ iface.launch()