Spaces:
Running
Running
import streamlit as st | |
import torch | |
import torch.nn as nn | |
import torch.optim as optim | |
import torchvision | |
import torchvision.transforms as transforms | |
from torchvision.utils import make_grid | |
import matplotlib.pyplot as plt | |
# Run on the GPU when CUDA is available, otherwise stay on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Hyperparameters
z_dim = 64  # length of the noise vector fed to the generator
image_dim = 28 * 28  # number of values in one flattened image
batch_size = 32
lr = 0.0003

# Load Data
# Images are converted to tensors and then normalized with mean 0.5 / std 0.5.
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5,), std=(0.5,)),
    ]
)
# MNIST is fetched into ``dataset/`` on first use.
dataset = torchvision.datasets.MNIST(
    root="dataset/",
    download=True,
    transform=transform,
)
dataloader = torch.utils.data.DataLoader(
    dataset=dataset,
    batch_size=batch_size,
    shuffle=True,
)
# Generator | |
class Generator(nn.Module):
    """Fully connected generator.

    Maps a ``z_dim``-sized input to an ``img_dim``-sized output through three
    hidden ReLU layers (256, 512 and 1024 units) followed by a ``Tanh``
    output layer.
    """

    def __init__(self, z_dim, img_dim):
        super().__init__()
        # Consecutive pairs of widths define the hidden Linear layers.
        widths = (z_dim, 256, 512, 1024)
        stack = []
        for n_in, n_out in zip(widths, widths[1:]):
            stack.append(nn.Linear(n_in, n_out))
            stack.append(nn.ReLU())
        # Output projection; Tanh bounds every value to [-1, 1].
        stack.append(nn.Linear(widths[-1], img_dim))
        stack.append(nn.Tanh())
        self.gen = nn.Sequential(*stack)

    def forward(self, x):
        """Run ``x`` through the layer stack and return the result."""
        return self.gen(x)
# Discriminator | |
class Discriminator(nn.Module):
    """Fully connected discriminator.

    Reduces an ``img_dim``-sized input through three hidden ReLU layers
    (1024, 512 and 256 units) to a single value squashed by ``Sigmoid``.
    """

    def __init__(self, img_dim):
        super().__init__()
        hidden_sizes = [1024, 512, 256]
        modules = []
        in_features = img_dim
        for out_features in hidden_sizes:
            modules += [nn.Linear(in_features, out_features), nn.ReLU()]
            in_features = out_features
        # Single output unit; Sigmoid turns it into a value in (0, 1).
        modules += [nn.Linear(in_features, 1), nn.Sigmoid()]
        self.disc = nn.Sequential(*modules)

    def forward(self, x):
        """Run ``x`` through the layer stack and return the result."""
        return self.disc(x)
# Build both networks on the selected device (generator first, then discriminator).
gen = Generator(z_dim=z_dim, img_dim=image_dim).to(device)
disc = Discriminator(img_dim=image_dim).to(device)

# One Adam optimizer per network, both using the same learning rate.
opt_gen = optim.Adam(params=gen.parameters(), lr=lr)
opt_disc = optim.Adam(params=disc.parameters(), lr=lr)

# Binary cross-entropy, applied to the discriminator's sigmoid outputs.
criterion = nn.BCELoss()
# Function to train the model | |
def train_gan(epochs):
    """Train the module-level GAN and return the last batch of generated images.

    For every mini-batch the discriminator ``disc`` is updated first, learning
    to score real images as 1 and generated images as 0. The generator ``gen``
    is then updated so that the freshly updated discriminator scores its
    samples as 1. After each epoch the losses of that epoch's final batch are
    reported through ``st.write``.

    Args:
        epochs: Number of full passes over ``dataloader``; must be at least 1.

    Returns:
        The generator output for the last processed batch (produced before
        that batch's generator update), one flattened image per row, detached
        from the autograd graph.

    Raises:
        ValueError: If ``epochs`` is smaller than 1.
        RuntimeError: If ``dataloader`` yields no batches.
    """
    if epochs < 1:
        # Without at least one pass there would be no samples to return.
        raise ValueError(f"epochs must be at least 1, got {epochs}")

    fake = None
    for epoch in range(epochs):
        loss_d = loss_g = None
        for real, _ in dataloader:
            # Flatten each image to a vector; the size is taken from the batch
            # itself instead of a hard-coded 784.
            n_samples = real.shape[0]
            real = real.view(n_samples, -1).to(device)

            # Train Discriminator
            noise = torch.randn(n_samples, z_dim, device=device)
            fake = gen(noise)
            disc_real = disc(real).view(-1)
            loss_d_real = criterion(disc_real, torch.ones_like(disc_real))
            # detach(): the discriminator update must not backpropagate into
            # the generator, and the generator graph stays intact for the
            # generator step below without retain_graph.
            disc_fake = disc(fake.detach()).view(-1)
            loss_d_fake = criterion(disc_fake, torch.zeros_like(disc_fake))
            loss_d = (loss_d_real + loss_d_fake) / 2
            disc.zero_grad()
            loss_d.backward()
            opt_disc.step()

            # Train Generator: score the same samples with the updated
            # discriminator and push those scores towards "real".
            output = disc(fake).view(-1)
            loss_g = criterion(output, torch.ones_like(output))
            gen.zero_grad()
            loss_g.backward()
            opt_gen.step()

        if loss_d is None:
            raise RuntimeError("dataloader yielded no batches; nothing to train on")

        st.write(
            f"Epoch [{epoch+1}/{epochs}] "
            f"Loss D: {loss_d.item():.4f}, Loss G: {loss_g.item():.4f}"
        )

    # Return plain data: callers only display the samples.
    return fake.detach()
# Streamlit interface
st.title("Simple GAN with Epoch Slider")
epochs = st.slider("Number of Epochs", 1, 100, 1)

if st.button("Train GAN"):
    fake_images = train_gan(epochs)
    # Reshape the flattened samples to 1x28x28 images; detach so the display
    # code never touches the autograd graph.
    fake_images = fake_images.detach().view(-1, 1, 28, 28)
    grid = make_grid(fake_images, nrow=8, normalize=True)

    # Draw on an explicit figure instead of pyplot's shared global figure, and
    # close it after rendering so figures do not pile up across reruns.
    fig, ax = plt.subplots()
    ax.imshow(grid.permute(1, 2, 0).cpu().numpy(), cmap="gray")
    st.pyplot(fig)
    plt.close(fig)