import torch
import torch.nn as nn

EMBEDDING_SIZE = 64


class EmbedDoodle(nn.Module):
    def __init__(self, embedding_size: int):
        super().__init__()

        latent_size = 256
        embed_depth = 5

        # A small residual cell: two SELU-activated linear layers with an
        # optional dropout layer, used as the repeating unit of the embedding path.
        def make_cell(in_size: int, hidden_size: int, out_size: int, add_dropout: bool):
            cell = nn.Sequential()
            cell.append(nn.Linear(in_size, hidden_size))
            cell.append(nn.SELU())
            cell.append(nn.Linear(hidden_size, hidden_size))
            if add_dropout:
                cell.append(nn.Dropout())
            cell.append(nn.SELU())
            cell.append(nn.Linear(hidden_size, out_size))
            return cell

        # Convolutional stem: four unpadded 3x3 convolutions take a 1x32x32
        # doodle down to 64x24x24 = 36864 features, which are flattened and
        # projected into the latent space.
        self.preprocess = nn.Sequential(
            nn.Conv2d(kernel_size=3, in_channels=1, out_channels=64),
            nn.Conv2d(kernel_size=3, in_channels=64, out_channels=64),
            nn.SELU(),
            nn.Conv2d(kernel_size=3, in_channels=64, out_channels=64),
            nn.Conv2d(kernel_size=3, in_channels=64, out_channels=64),
            nn.Dropout(),
            nn.SELU(),

            nn.Flatten(),
            nn.Linear(36864, latent_size),
            nn.SELU(),
        )

        # A stack of residual cells operating in the latent space.
        self.embedding_path = nn.ModuleList()
        for _ in range(embed_depth):
            self.embedding_path.append(make_cell(latent_size, latent_size, latent_size, add_dropout=True))

        # Final projection from the latent space to the embedding size.
        self.embedding_head = nn.Linear(latent_size, embedding_size)

    def forward(self, x):
        # Accept flat or batched input and reshape to (N, 1, 32, 32).
        x = x.view(-1, 1, 32, 32)

        x = self.preprocess(x)

        # Residual connections: each cell refines the latent vector in place.
        for c in self.embedding_path:
            x = x + c(x)

        x = self.embedding_head(x)
        # L2-normalize so embeddings lie on the unit hypersphere.
        embedding = nn.functional.normalize(x, dim=-1)
        return embedding
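

# A minimal usage sketch, assuming 32x32 single-channel doodles batched as
# float tensors (the shape forward() expects); the random batch below is
# hypothetical and only illustrates the output shape and unit normalization.
if __name__ == "__main__":
    model = EmbedDoodle(EMBEDDING_SIZE)
    model.eval()  # disable dropout for inference

    batch = torch.rand(8, 1, 32, 32)
    with torch.no_grad():
        embeddings = model(batch)

    print(embeddings.shape)         # torch.Size([8, 64])
    print(embeddings.norm(dim=-1))  # ~1.0 for every row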