# Build Model:

import torch
import torch.nn as nn

EMBEDDING_SIZE = 64


class EmbedDoodle(nn.Module):
    def __init__(self, embedding_size: int):
        # Inputs: 32x32 binary image
        # Outputs: An embedding of said image.
        super().__init__()
        latent_size = 256
        embed_depth = 5
        #self.input_conv = nn.Conv2d(kernel_size=3, in_channels=1, out_channels=16)

        # A small two-hidden-layer MLP used as one residual cell in the embedding path.
        def make_cell(in_size: int, hidden_size: int, out_size: int, add_dropout: bool):
            cell = nn.Sequential()
            cell.append(nn.Linear(in_size, hidden_size))
            cell.append(nn.SELU())
            cell.append(nn.Linear(hidden_size, hidden_size))
            if add_dropout:
                cell.append(nn.Dropout())
            cell.append(nn.SELU())
            cell.append(nn.Linear(hidden_size, out_size))
            return cell

        # Conv stack: four 3x3 convs with no padding shrink 32 -> 30 -> 28 -> 26 -> 24,
        # so the flattened feature size is 64 * 24 * 24 = 36864.
        self.preprocess = nn.Sequential(
            nn.Conv2d(kernel_size=3, in_channels=1, out_channels=64),
            nn.Conv2d(kernel_size=3, in_channels=64, out_channels=64),
            nn.SELU(),
            nn.Conv2d(kernel_size=3, in_channels=64, out_channels=64),
            nn.Conv2d(kernel_size=3, in_channels=64, out_channels=64),
            nn.Dropout(),
            nn.SELU(),
            #nn.AvgPool2d(kernel_size=3),
            # bx4097
            nn.Flatten(),
            nn.Linear(36864, latent_size),
            nn.SELU(),
        )

        self.embedding_path = nn.ModuleList()
        for _ in range(embed_depth):
            self.embedding_path.append(make_cell(latent_size, latent_size, latent_size, add_dropout=True))
        self.embedding_head = nn.Linear(latent_size, embedding_size)

    def forward(self, x):
        x = x.view(-1, 1, 32, 32)
        x = self.preprocess(x)
        # We should do this with a dot product to combine these to really get the effects of a highway/resnet.
        for c in self.embedding_path:
            x = x + c(x)
        x = self.embedding_head(x)
        embedding = nn.functional.normalize(x, dim=-1)
        return embedding
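
# --- Quick sanity check (illustrative sketch, not part of the original model code) ---
# Assumes a dummy batch of 32x32 binary doodles; verifies the output shape and that the
# embeddings come out unit-length after the final normalize.
if __name__ == "__main__":
    model = EmbedDoodle(EMBEDDING_SIZE)
    model.eval()  # disable Dropout for a deterministic check
    with torch.no_grad():
        batch = torch.randint(0, 2, (8, 32, 32)).float()  # 8 fake binary doodles
        emb = model(batch)
    print(emb.shape)        # expected: torch.Size([8, 64])
    print(emb.norm(dim=-1)) # expected: values ~1.0, since forward() L2-normalizes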