import torch
import numpy as np
import matplotlib.pyplot as plt  # NOTE(review): not used in the code shown; kept in case other code relies on it
import gradio as gr

# Architecture / app constants. The architecture values must match the ones
# the checkpoint was trained with, otherwise load_state_dict() fails with a
# shape mismatch.
IMAGE_SIDE = 28
HIDDEN_DIM = 400
LATENT_DIM = 20
NUM_IMAGES = 5
MODEL_PATH = "cvae_mnist.pth"


# Define the same VAE architecture used during training
class VAE(torch.nn.Module):
    """Fully connected variational autoencoder for square grayscale images.

    The layer order inside the ``Sequential`` containers is significant: it
    fixes the ``state_dict`` keys (``encoder.1.*``, ``mu.*``, ``logvar.*``,
    ``decoder.0.*``, ``decoder.2.*``) that a saved checkpoint is matched
    against.

    Args:
        latent_dim: Size of the latent vector z.
        hidden_dim: Width of the single hidden layer in encoder and decoder.
        image_side: Side length of the (square) input/output image.
    """

    def __init__(self, latent_dim=LATENT_DIM, hidden_dim=HIDDEN_DIM, image_side=IMAGE_SIDE):
        super().__init__()
        # Plain attributes (not parameters), so they do not appear in state_dict.
        self.latent_dim = latent_dim
        self.image_side = image_side
        n_pixels = image_side * image_side
        self.encoder = torch.nn.Sequential(
            torch.nn.Flatten(),
            torch.nn.Linear(n_pixels, hidden_dim),
            torch.nn.ReLU(),
        )
        self.mu = torch.nn.Linear(hidden_dim, latent_dim)
        self.logvar = torch.nn.Linear(hidden_dim, latent_dim)
        self.decoder = torch.nn.Sequential(
            torch.nn.Linear(latent_dim, hidden_dim),
            torch.nn.ReLU(),
            torch.nn.Linear(hidden_dim, n_pixels),
            torch.nn.Sigmoid(),  # decoded pixel values lie in [0, 1]
        )

    def reparameterize(self, mu, logvar):
        """Return ``mu + eps * exp(0.5 * logvar)`` with ``eps ~ N(0, I)``."""
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def forward(self, x):
        """Encode ``x``, sample a latent and return decoded pixels in [0, 1].

        Sampling happens regardless of train/eval mode, so the output is
        stochastic. The result is flat, with ``image_side * image_side``
        values per sample, and is not reshaped into an image.
        """
        h = self.encoder(x)
        mu, logvar = self.mu(h), self.logvar(h)
        z = self.reparameterize(mu, logvar)
        return self.decoder(z)


# Load model for CPU inference.
# NOTE(review): torch.load unpickles the file, so only load trusted checkpoints.
# On torch >= 1.13, weights_only=True would be safer because only a state_dict
# is needed here.
# NOTE(review): the file name says "cvae" but VAE above is unconditional (it
# takes no label input). Confirm the checkpoint really matches this architecture.
model = VAE()
model.load_state_dict(torch.load(MODEL_PATH, map_location="cpu"))
model.eval()


# Generation function for Gradio
def generate_images(digit):
    """Sample ``NUM_IMAGES`` random digit images from the VAE prior.

    Args:
        digit: Value selected in the Gradio dropdown (a string "0"-"9").
            Currently ignored: the model is unconditional and cannot be
            steered toward a particular digit.

    Returns:
        A list of ``NUM_IMAGES`` uint8 arrays of shape
        ``(model.image_side, model.image_side)`` with values 0-255, one per
        Gradio Image output.
    """
    # Draw all latents at once and decode them in a single batched pass.
    # no_grad avoids building an autograd graph for pure inference.
    z = torch.randn(NUM_IMAGES, model.latent_dim)
    with torch.no_grad():
        pixels = model.decoder(z)
    side = model.image_side
    batch = pixels.cpu().numpy().reshape(NUM_IMAGES, side, side)
    # The decoder ends in Sigmoid (values in [0, 1]); round to the nearest
    # 8-bit level instead of truncating.
    images = np.rint(batch * 255).astype(np.uint8)
    return list(images)


# Gradio interface
iface = gr.Interface(
    fn=generate_images,
    inputs=gr.Dropdown(choices=[str(i) for i in range(10)], label="Choose a digit (ignored for now)"),
    outputs=[gr.Image(image_mode='L') for _ in range(NUM_IMAGES)],
    title="Handwritten Digit Generator",
    description="Select a digit (0–9) and generate 5 handwritten-style digits using a VAE trained on MNIST.",
)

iface.launch()