import torch
import torch.nn as nn
import numpy as np
from torchvision.utils import save_image, make_grid
import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation, PillowWriter
import os
import torchvision.transforms as transforms
from torch.utils.data import Dataset
from PIL import Image

class ResidualConvBlock(nn.Module):
    def __init__(
        self, in_channels: int, out_channels: int, is_res: bool = False
    ) -> None:
        super().__init__()

        # Check if input and output channels are the same for the residual connection
        self.same_channels = in_channels == out_channels

        # Flag for whether or not to use residual connection
        self.is_res = is_res

        # First convolutional layer
        self.conv1 = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 3, 1, 1),   # 3x3 kernel with stride 1 and padding 1
            nn.BatchNorm2d(out_channels),                    # Batch normalization
            nn.GELU(),                                       # GELU activation function
        )

        # Second convolutional layer
        self.conv2 = nn.Sequential(
            nn.Conv2d(out_channels, out_channels, 3, 1, 1),  # 3x3 kernel with stride 1 and padding 1
            nn.BatchNorm2d(out_channels),                    # Batch normalization
            nn.GELU(),                                       # GELU activation function
        )

        # 1x1 convolution to match dimensions for the residual connection when the
        # channel counts differ. Defined here rather than inside forward() so its
        # weights are registered with the module and trained, instead of being
        # re-initialized on every forward pass.
        if self.is_res and not self.same_channels:
            self.shortcut = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # If using residual connection
        if self.is_res:
            # Apply first convolutional layer
            x1 = self.conv1(x)

            # Apply second convolutional layer
            x2 = self.conv2(x1)

            # If input and output channels are the same, add residual connection directly
            if self.same_channels:
                out = x + x2
            else:
                # If not, apply the 1x1 convolutional layer to match dimensions before adding
                out = self.shortcut(x) + x2
            # print(f"resconv forward: x {x.shape}, x1 {x1.shape}, x2 {x2.shape}, out {out.shape}")

            # Normalize output tensor (divide by sqrt(2) to keep the variance stable)
            return out / 1.414

        # If not using residual connection, return output of second convolutional layer
        else:
            x1 = self.conv1(x)
            x2 = self.conv2(x1)
            return x2

    # Method to get the number of output channels for this block
    def get_out_channels(self):
        return self.conv2[0].out_channels

    # Method to set the number of output channels for this block
    def set_out_channels(self, out_channels):
        self.conv1[0].out_channels = out_channels
        self.conv2[0].in_channels = out_channels
        self.conv2[0].out_channels = out_channels
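
# --- Illustrative sketch (not part of the original utilities): a quick shape check
# for ResidualConvBlock. The batch size, channel counts, and image size below are
# assumptions chosen only to show that the block keeps the spatial dimensions.
def _demo_residual_block():
    block = ResidualConvBlock(in_channels=3, out_channels=64, is_res=True)
    x = torch.randn(2, 3, 16, 16)   # assumed: batch of 2 RGB 16x16 images
    out = block(x)
    print(out.shape)                # torch.Size([2, 64, 16, 16])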

class UnetUp(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(UnetUp, self).__init__()

        # Create a list of layers for the upsampling block
        # The block consists of a ConvTranspose2d layer for upsampling, followed by two ResidualConvBlock layers
        layers = [
            nn.ConvTranspose2d(in_channels, out_channels, 2, 2),
            ResidualConvBlock(out_channels, out_channels),
            ResidualConvBlock(out_channels, out_channels),
        ]

        # Use the layers to create a sequential model
        self.model = nn.Sequential(*layers)

    def forward(self, x, skip):
        # Concatenate the input tensor x with the skip connection tensor along the channel dimension
        x = torch.cat((x, skip), 1)

        # Pass the concatenated tensor through the sequential model and return the output
        x = self.model(x)
        return x

class UnetDown(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(UnetDown, self).__init__()

        # Create a list of layers for the downsampling block
        # Each block consists of two ResidualConvBlock layers, followed by a MaxPool2d layer for downsampling
        layers = [
            ResidualConvBlock(in_channels, out_channels),
            ResidualConvBlock(out_channels, out_channels),
            nn.MaxPool2d(2),
        ]

        # Use the layers to create a sequential model
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        # Pass the input through the sequential model and return the output
        return self.model(x)
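
# --- Illustrative sketch (not part of the original utilities): how UnetDown and
# UnetUp pair up. UnetDown halves the spatial size; UnetUp expects the concatenation
# of a feature map and a skip connection, then doubles the spatial size. Reusing the
# same tensor as its own skip connection here is purely for the shape demonstration.
def _demo_down_up():
    down = UnetDown(3, 64)          # 3 -> 64 channels, halves H and W
    up = UnetUp(128, 3)             # expects cat of two 64-channel maps, restores H and W
    x0 = torch.randn(1, 3, 16, 16)  # assumed input size
    d = down(x0)                    # (1, 64, 8, 8)
    out = up(d, d)                  # (1, 3, 16, 16)
    print(d.shape, out.shape)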

class EmbedFC(nn.Module):
    def __init__(self, input_dim, emb_dim):
        '''
        This class defines a generic one-layer feed-forward neural network for embedding input data of
        dimensionality input_dim to an embedding space of dimensionality emb_dim.
        '''
        super(EmbedFC, self).__init__()

        self.input_dim = input_dim

        # define the layers for the network
        layers = [
            nn.Linear(input_dim, emb_dim),
            nn.GELU(),
            nn.Linear(emb_dim, emb_dim),
        ]

        # create a PyTorch sequential model consisting of the defined layers
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        # flatten the input tensor
        x = x.view(-1, self.input_dim)
        # apply the model layers to the flattened tensor
        return self.model(x)
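
# --- Illustrative sketch (not part of the original utilities): EmbedFC can project
# conditioning signals such as a timestep scalar or a context vector into a common
# embedding space. The dimensions below are assumptions for demonstration only.
def _demo_embed():
    t_embed = EmbedFC(input_dim=1, emb_dim=128)   # embed a scalar timestep
    c_embed = EmbedFC(input_dim=5, emb_dim=128)   # embed a 5-dim context vector
    t = torch.rand(4, 1)                          # 4 normalized timesteps
    c = torch.zeros(4, 5)                         # 4 one-hot-style context vectors
    print(t_embed(t).shape, c_embed(c).shape)     # both torch.Size([4, 128])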

def unorm(x):
    # unity norm: rescales values into the range [0, 1]
    # assumes x has shape (h, w, 3)
    xmax = x.max((0, 1))
    xmin = x.min((0, 1))
    return (x - xmin) / (xmax - xmin)

def norm_all(store, n_t, n_s):
    # runs unity norm on all timesteps of all samples
    nstore = np.zeros_like(store)
    for t in range(n_t):
        for s in range(n_s):
            nstore[t, s] = unorm(store[t, s])
    return nstore

def norm_torch(x_all):
    # runs unity norm on all timesteps of all samples
    # input is (n_samples, 3, h, w), the torch image format
    x = x_all.cpu().numpy()
    xmax = x.max((2, 3))
    xmin = x.min((2, 3))
    xmax = np.expand_dims(xmax, (2, 3))
    xmin = np.expand_dims(xmin, (2, 3))
    nstore = (x - xmin) / (xmax - xmin)
    return torch.from_numpy(nstore)

def gen_tst_context(n_cfeat):
    """
    Generate test context vectors
    """
    vec = torch.tensor([
        [1, 0, 0, 0, 0], [0, 1, 0, 0, 0], [0, 0, 1, 0, 0], [0, 0, 0, 1, 0], [0, 0, 0, 0, 1], [0, 0, 0, 0, 0],  # human, non-human, food, spell, side-facing
        [1, 0, 0, 0, 0], [0, 1, 0, 0, 0], [0, 0, 1, 0, 0], [0, 0, 0, 1, 0], [0, 0, 0, 0, 1], [0, 0, 0, 0, 0],  # human, non-human, food, spell, side-facing
        [1, 0, 0, 0, 0], [0, 1, 0, 0, 0], [0, 0, 1, 0, 0], [0, 0, 0, 1, 0], [0, 0, 0, 0, 1], [0, 0, 0, 0, 0],  # human, non-human, food, spell, side-facing
        [1, 0, 0, 0, 0], [0, 1, 0, 0, 0], [0, 0, 1, 0, 0], [0, 0, 0, 1, 0], [0, 0, 0, 0, 1], [0, 0, 0, 0, 0],  # human, non-human, food, spell, side-facing
        [1, 0, 0, 0, 0], [0, 1, 0, 0, 0], [0, 0, 1, 0, 0], [0, 0, 0, 1, 0], [0, 0, 0, 0, 1], [0, 0, 0, 0, 0],  # human, non-human, food, spell, side-facing
        [1, 0, 0, 0, 0], [0, 1, 0, 0, 0], [0, 0, 1, 0, 0], [0, 0, 0, 1, 0], [0, 0, 0, 0, 1], [0, 0, 0, 0, 0],  # human, non-human, food, spell, side-facing
    ])
    return len(vec), vec

def plot_grid(x, n_sample, n_rows, save_dir, w):
    # x: (n_sample, 3, h, w)
    ncols = n_sample // n_rows
    grid = make_grid(norm_torch(x), nrow=ncols)  # note: make_grid's nrow argument is the number of images per row, i.e. the number of columns
    save_image(grid, save_dir + f"run_image_w{w}.png")
    print('saved image at ' + save_dir + f"run_image_w{w}.png")
    return grid

def plot_sample(x_gen_store, n_sample, nrows, save_dir, fn, w, save=False):
    ncols = n_sample // nrows
    sx_gen_store = np.moveaxis(x_gen_store, 2, 4)                            # change to Numpy image format (h, w, channels) from (channels, h, w)
    nsx_gen_store = norm_all(sx_gen_store, sx_gen_store.shape[0], n_sample)  # unity norm to put in range [0,1] for plt.imshow

    # create gif of images evolving over time, based on x_gen_store
    fig, axs = plt.subplots(nrows=nrows, ncols=ncols, sharex=True, sharey=True, figsize=(ncols, nrows))

    def animate_diff(i, store):
        print(f'gif animating frame {i} of {store.shape[0]}', end='\r')
        plots = []
        for row in range(nrows):
            for col in range(ncols):
                axs[row, col].clear()
                axs[row, col].set_xticks([])
                axs[row, col].set_yticks([])
                plots.append(axs[row, col].imshow(store[i, (row * ncols) + col]))
        return plots

    ani = FuncAnimation(fig, animate_diff, fargs=[nsx_gen_store], interval=200, blit=False, repeat=True, frames=nsx_gen_store.shape[0])
    plt.close()
    if save:
        ani.save(save_dir + f"{fn}_w{w}.gif", dpi=100, writer=PillowWriter(fps=5))
        print('saved gif at ' + save_dir + f"{fn}_w{w}.gif")
    return ani

class CustomDataset(Dataset):
    def __init__(self, sfilename, lfilename, transform, null_context=False):
        self.sprites = np.load(sfilename)
        self.slabels = np.load(lfilename)
        print(f"sprite shape: {self.sprites.shape}")
        print(f"labels shape: {self.slabels.shape}")
        self.transform = transform
        self.null_context = null_context
        self.sprites_shape = self.sprites.shape
        self.slabel_shape = self.slabels.shape

    # Return the number of images in the dataset
    def __len__(self):
        return len(self.sprites)

    # Get the image and label at a given index
    def __getitem__(self, idx):
        # Apply the transform if one was provided; otherwise return the raw array
        if self.transform:
            image = self.transform(self.sprites[idx])
        else:
            image = self.sprites[idx]
        if self.null_context:
            label = torch.tensor(0).to(torch.int64)
        else:
            label = torch.tensor(self.slabels[idx]).to(torch.int64)
        # Return the image and label as a tuple
        return (image, label)

    def getshapes(self):
        # return shapes of data and labels
        return self.sprites_shape, self.slabel_shape

transform = transforms.Compose([
    transforms.ToTensor(),                  # from [0, 255] to range [0.0, 1.0]
    transforms.Normalize((0.5,), (0.5,)),   # range [-1, 1]
])
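
# --- Illustrative sketch (not part of the original utilities): wiring the transform
# into CustomDataset and a DataLoader. The .npy filenames and batch size below are
# assumptions; substitute the actual sprite and label files used by the training script.
def _demo_dataloader():
    from torch.utils.data import DataLoader
    dataset = CustomDataset("sprites.npy", "sprite_labels.npy", transform, null_context=False)
    loader = DataLoader(dataset, batch_size=32, shuffle=True)
    images, labels = next(iter(loader))
    print(images.shape, labels.shape)  # e.g. (32, 3, 16, 16) and (32, 5), depending on the files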