JosephCatrambone commited on
Commit
70bfb14
·
verified ·
1 Parent(s): 4245406

Add py file for architecture of embedding model.

Browse files
Files changed (1) hide show
  1. tiny_doodle_embedding_model.py +62 -0
tiny_doodle_embedding_model.py ADDED
@@ -0,0 +1,62 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Build Model:
2
+ import torch
3
+ import torch.nn as nn
4
+
5
+ EMBEDDING_SIZE = 64
6
+
7
+ class EmbedDoodle(nn.Module):
8
+ def __init__(self, embedding_size: int):
9
+ # Inputs: 32x32 binary image
10
+ # Outputs: An embedding of said image.
11
+ super().__init__()
12
+
13
+ latent_size = 256
14
+ embed_depth = 5
15
+
16
+ #self.input_conv = nn.Conv2d(kernel_size=3, in_channels=1, out_channels=16)
17
+
18
+ def make_cell(in_size: int, hidden_size: int, out_size: int, add_dropout: bool):
19
+ cell = nn.Sequential()
20
+ cell.append(nn.Linear(in_size, hidden_size))
21
+ cell.append(nn.SELU())
22
+ cell.append(nn.Linear(hidden_size, hidden_size))
23
+ if add_dropout:
24
+ cell.append(nn.Dropout())
25
+ cell.append(nn.SELU())
26
+ cell.append(nn.Linear(hidden_size, out_size))
27
+ return cell
28
+
29
+ self.preprocess = nn.Sequential(
30
+ nn.Conv2d(kernel_size=3, in_channels=1, out_channels=64),
31
+ nn.Conv2d(kernel_size=3, in_channels=64, out_channels=64),
32
+ nn.SELU(),
33
+ nn.Conv2d(kernel_size=3, in_channels=64, out_channels=64),
34
+ nn.Conv2d(kernel_size=3, in_channels=64, out_channels=64),
35
+ nn.Dropout(),
36
+ nn.SELU(),
37
+ #nn.AvgPool2d(kernel_size=3), # bx4097
38
+
39
+ nn.Flatten(),
40
+ nn.Linear(36864, latent_size),
41
+ nn.SELU(),
42
+ )
43
+
44
+ self.embedding_path = nn.ModuleList()
45
+ for i in range(0, embed_depth):
46
+ self.embedding_path.append(make_cell(latent_size, latent_size, latent_size, add_dropout=True))
47
+
48
+ self.embedding_head = nn.Linear(latent_size, embedding_size)
49
+
50
+ def forward(self, x):
51
+ x = x.view(-1, 1, 32, 32)
52
+
53
+ x = self.preprocess(x)
54
+
55
+ # We should do this with a dot product to combine these to really get the effects of a highway/resnet.
56
+ for c in self.embedding_path:
57
+ x = x + c(x)
58
+
59
+ x = self.embedding_head(x)
60
+ embedding = nn.functional.normalize(x, dim=-1)
61
+ return embedding
62
+