File size: 2,104 Bytes
70bfb14
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
# Build Model:
import torch
import torch.nn as nn

# Embedding dimensionality constant. It is not referenced in this part of the
# file; presumably it is passed as EmbedDoodle(EMBEDDING_SIZE) elsewhere
# (TODO confirm against callers).
EMBEDDING_SIZE = 64

class EmbedDoodle(nn.Module):
    """Map a 32x32 single-channel (binary) doodle image to a unit-length embedding.

    Architecture:
      * a convolutional front-end of four unpadded 3x3 convolutions
        (64 channels, 32x32 -> 24x24 spatially), flattened and projected to
        ``latent_size`` features;
      * ``embed_depth`` residual MLP cells (``x = x + cell(x)``);
      * a linear head to ``embedding_size``, followed by L2 normalisation
        along the last dimension.

    Layer order (and therefore ``state_dict`` keys such as
    ``preprocess.8.weight`` or ``embedding_path.0.5.weight``) is identical to
    the original implementation, so existing checkpoints still load.

    Args:
        embedding_size: Dimensionality of the returned embedding.
        latent_size: Width of the hidden representation between the
            convolutional front-end and the embedding head (default 256).
        embed_depth: Number of residual MLP cells (default 5).
    """

    # Side length of the square input image; forward() reshapes to this.
    _IMAGE_SIDE = 32
    # Channel count used by every convolution in the front-end.
    _CONV_CHANNELS = 64
    # Number of unpadded 3x3 convolutions in the front-end.
    _NUM_CONVS = 4

    def __init__(
        self,
        embedding_size: int,
        *,
        latent_size: int = 256,
        embed_depth: int = 5,
    ):
        super().__init__()

        def make_cell(
            in_size: int, hidden_size: int, out_size: int, add_dropout: bool
        ) -> nn.Sequential:
            """Build one MLP cell: Linear-SELU-Linear-[Dropout]-SELU-Linear."""
            layers = [
                nn.Linear(in_size, hidden_size),
                nn.SELU(),
                nn.Linear(hidden_size, hidden_size),
            ]
            if add_dropout:
                # NOTE(review): SELU self-normalisation is usually paired with
                # nn.AlphaDropout; plain Dropout is kept to preserve behaviour.
                layers.append(nn.Dropout())
            layers += [nn.SELU(), nn.Linear(hidden_size, out_size)]
            return nn.Sequential(*layers)

        channels = self._CONV_CHANNELS
        # Each unpadded 3x3 conv trims 2 pixels from height and width:
        # 32 -> 30 -> 28 -> 26 -> 24, so 64 * 24 * 24 == 36864 features.
        conv_side = self._IMAGE_SIDE - self._NUM_CONVS * (3 - 1)
        flat_size = channels * conv_side * conv_side

        self.preprocess = nn.Sequential(
            nn.Conv2d(kernel_size=3, in_channels=1, out_channels=channels),
            nn.Conv2d(kernel_size=3, in_channels=channels, out_channels=channels),
            nn.SELU(),
            nn.Conv2d(kernel_size=3, in_channels=channels, out_channels=channels),
            nn.Conv2d(kernel_size=3, in_channels=channels, out_channels=channels),
            nn.Dropout(),
            nn.SELU(),
            nn.Flatten(),
            nn.Linear(flat_size, latent_size),
            nn.SELU(),
        )

        # Cells are created in order after the front-end, matching the
        # original parameter-initialisation order under a fixed seed.
        self.embedding_path = nn.ModuleList(
            make_cell(latent_size, latent_size, latent_size, add_dropout=True)
            for _ in range(embed_depth)
        )

        self.embedding_head = nn.Linear(latent_size, embedding_size)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Embed a batch of images.

        Args:
            x: Any tensor whose elements can be regrouped into
                ``(batch, 1, 32, 32)``, e.g. ``(batch, 32, 32)`` or
                ``(batch, 1024)``.

        Returns:
            Tensor of shape ``(batch, embedding_size)`` with unit L2 norm rows.
        """
        # reshape (not view) so non-contiguous inputs such as permuted or
        # sliced batches are accepted; contiguous inputs still get a view.
        x = x.reshape(-1, 1, self._IMAGE_SIDE, self._IMAGE_SIDE)

        x = self.preprocess(x)

        # Residual connection around each cell.
        # TODO: the original author suggests combining these with a dot
        # product to get fuller highway/resnet-style effects.
        for cell in self.embedding_path:
            x = x + cell(x)

        return nn.functional.normalize(self.embedding_head(x), dim=-1)