# Hugging Face Spaces status: Runtime error
| import streamlit as st | |
| import torch.nn as nn | |
| import torchvision.transforms as T | |
| from torchvision.utils import make_grid | |
| import torch | |
# Compute device shared by every generator below: the default CUDA device
# when PyTorch reports one is available, otherwise the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
def inference_gan():
    """Sample a grid of images from the TorchScript MNIST generator.

    Loads ``models/mnist-G-torchscript.pt`` onto ``device``, feeds it 30
    random latent vectors of length 256, reshapes the outputs to
    single-channel 28x28 images and tiles them 8 per row.

    Returns:
        A PIL image containing the tiled samples.
    """
    # map_location lets an archive that was saved from a GPU load on a
    # CPU-only host (and vice versa).
    generator = torch.jit.load("models/mnist-G-torchscript.pt", map_location=device)

    # Sample on the model's device; a hard-coded "cuda" here fails on hosts
    # without a GPU. no_grad: pure inference, so skip building autograd state.
    with torch.no_grad():
        latent = torch.randn(30, 256, device=device)
        samples = generator(latent)

    # Shape the generator output as (N, 1, 28, 28) single-channel images.
    samples = samples.reshape(-1, 1, 28, 28)

    # NOTE(review): no denormalization is applied here. If this generator's
    # output range is [-1, 1] (e.g. a tanh head), negative pixels will not
    # render correctly in to_pil_image — confirm the model's output range.
    grid = make_grid(samples.cpu(), nrow=8)
    return T.functional.to_pil_image(grid)
def inference_dcgan():
    """Sample a grid of images from the TorchScript anime-face DCGAN generator.

    Loads ``models/animefacedataset-G2-torchscript.pt`` onto ``device``, feeds
    it 64 random latent tensors of shape (128, 1, 1), reshapes the outputs to
    3-channel 64x64 images, maps them back to [0, 1] and tiles them 8 per row.

    Returns:
        A PIL image containing the tiled samples.
    """
    # map_location lets an archive that was saved from a GPU load on a
    # CPU-only host (and vice versa).
    generator = torch.jit.load(
        "models/animefacedataset-G2-torchscript.pt", map_location=device
    )

    def denorm(img_tensors):
        # Invert a (mean, std) = (0.5, 0.5) normalization: x * std + mean.
        # Presumably mirrors the transform used when the model was trained.
        stats = (0.5, 0.5, 0.5), (0.5, 0.5, 0.5)
        return img_tensors * stats[1][0] + stats[0][0]

    # Sample on the model's device; a hard-coded "cuda" here fails on hosts
    # without a GPU. no_grad: pure inference, so skip building autograd state.
    with torch.no_grad():
        latent = torch.randn(64, 128, 1, 1, device=device)
        samples = generator(latent)

    # Shape the generator output as (N, 3, 64, 64) RGB images.
    samples = samples.reshape(-1, 3, 64, 64)

    # to_pil_image scales float tensors by 255 and casts to uint8, so clamp
    # first: values outside [0, 1] would otherwise not saturate correctly.
    images = denorm(samples.cpu()).clamp(0, 1)
    grid = make_grid(images, nrow=8)
    return T.functional.to_pil_image(grid)
def inference_both():
    """Generate one image grid from each generator.

    Returns:
        A ``(gan_image, dcgan_image)`` tuple: the result of ``inference_gan()``
        followed by the result of ``inference_dcgan()``. Callers that ignore
        the return value see the same behavior as before.
    """
    gan_image = inference_gan()
    dcgan_image = inference_dcgan()
    return gan_image, dcgan_image
st.markdown("# Image Generation with GANs and DCGANs")

# Clicking a Streamlit button reruns this script from the top, and the rerun
# regenerates both grids below. An on_click callback that also ran both
# generators would only double the work and discard its results.
st.button("Generate Images")

st.image(inference_dcgan(), caption="", use_column_width=True)
st.image(inference_gan(), caption="", use_column_width=True)