debojit01 commited on
Commit
388fa58
·
verified ·
1 Parent(s): 97088f2

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +29 -26
app.py CHANGED
@@ -1,21 +1,24 @@
1
  import torch
2
  import numpy as np
3
- import matplotlib.pyplot as plt
4
  import gradio as gr
 
5
 
6
- # Define the same VAE architecture used during training
7
- class VAE(torch.nn.Module):
8
- def __init__(self):
9
  super().__init__()
 
 
 
10
  self.encoder = torch.nn.Sequential(
11
- torch.nn.Flatten(),
12
- torch.nn.Linear(28*28, 400),
13
  torch.nn.ReLU(),
14
  )
15
- self.mu = torch.nn.Linear(400, 20)
16
- self.logvar = torch.nn.Linear(400, 20)
 
17
  self.decoder = torch.nn.Sequential(
18
- torch.nn.Linear(20, 400),
19
  torch.nn.ReLU(),
20
  torch.nn.Linear(400, 28*28),
21
  torch.nn.Sigmoid()
@@ -26,34 +29,34 @@ class VAE(torch.nn.Module):
26
  eps = torch.randn_like(std)
27
  return mu + eps * std
28
 
29
- def forward(self, x):
30
- h = self.encoder(x)
31
- mu, logvar = self.mu(h), self.logvar(h)
32
- z = self.reparameterize(mu, logvar)
33
- return self.decoder(z)
34
 
35
- # Load model
36
- model = VAE()
37
  model.load_state_dict(torch.load("cvae_mnist.pth", map_location='cpu'))
38
  model.eval()
39
 
40
- # Generation function for Gradio
41
- def generate_images(digit):
42
- # For VAE, we ignore the digit and generate random samples
43
  images = []
44
  for _ in range(5):
45
  z = torch.randn(1, 20)
46
- img = model.decoder(z).detach().numpy().reshape(28, 28)
 
 
 
47
  images.append((img * 255).astype(np.uint8))
48
  return images
49
 
50
- # Gradio interface
51
  iface = gr.Interface(
52
- fn=generate_images,
53
- inputs=gr.Dropdown(choices=[str(i) for i in range(10)], label="Choose a digit (ignored for now)"),
54
  outputs=[gr.Image(image_mode='L') for _ in range(5)],
55
- title="Handwritten Digit Generator",
56
- description="Select a digit (0–9) and generate 5 handwritten-style digits using a VAE trained on MNIST."
57
  )
58
 
59
- iface.launch()
 
1
  import torch
2
  import numpy as np
 
3
  import gradio as gr
4
+ import matplotlib.pyplot as plt
5
 
6
+ # Conditional VAE definition (same as training)
7
+ class CVAE(torch.nn.Module):
8
+ def __init__(self, latent_dim=20):
9
  super().__init__()
10
+ self.latent_dim = latent_dim
11
+ self.label_embed = torch.nn.Embedding(10, 10)
12
+
13
  self.encoder = torch.nn.Sequential(
14
+ torch.nn.Linear(28*28 + 10, 400),
 
15
  torch.nn.ReLU(),
16
  )
17
+ self.fc_mu = torch.nn.Linear(400, latent_dim)
18
+ self.fc_logvar = torch.nn.Linear(400, latent_dim)
19
+
20
  self.decoder = torch.nn.Sequential(
21
+ torch.nn.Linear(latent_dim + 10, 400),
22
  torch.nn.ReLU(),
23
  torch.nn.Linear(400, 28*28),
24
  torch.nn.Sigmoid()
 
29
  eps = torch.randn_like(std)
30
  return mu + eps * std
31
 
32
+ def decode(self, z, y):
33
+ y_embed = self.label_embed(y)
34
+ inputs = torch.cat([z, y_embed], dim=1)
35
+ return self.decoder(inputs)
 
36
 
37
+ model = CVAE()
 
38
  model.load_state_dict(torch.load("cvae_mnist.pth", map_location='cpu'))
39
  model.eval()
40
 
41
+ # Image generation function
42
+ def generate_digit_images(digit):
 
43
  images = []
44
  for _ in range(5):
45
  z = torch.randn(1, 20)
46
+ y = torch.tensor([int(digit)])
47
+ with torch.no_grad():
48
+ out = model.decode(z, y)
49
+ img = out.view(28, 28).numpy()
50
  images.append((img * 255).astype(np.uint8))
51
  return images
52
 
53
+ # Launch Gradio app
54
  iface = gr.Interface(
55
+ fn=generate_digit_images,
56
+ inputs=gr.Dropdown(choices=[str(i) for i in range(10)], label="Choose a digit (0–9)"),
57
  outputs=[gr.Image(image_mode='L') for _ in range(5)],
58
+ title="Conditional VAE Handwritten Digit Generator",
59
+ description="Generates 5 images of the digit you select (0–9) using a Conditional Variational Autoencoder trained on MNIST."
60
  )
61
 
62
+ iface.launch()