eyupipler commited on
Commit
126eac4
·
verified ·
1 Parent(s): b4f9ca6

Create model.py

Browse files
Files changed (1) hide show
  1. model.py +39 -0
model.py ADDED
@@ -0,0 +1,39 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+
4
+ class SimpleCNN(nn.Module):
5
+ def __init__(self, num_classes=6):
6
+ super(SimpleCNN, self).__init__()
7
+ self.conv1 = nn.Conv2d(3, 32, kernel_size=3, padding=1)
8
+ self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
9
+ self.conv3 = nn.Conv2d(64, 128, kernel_size=3, padding=1)
10
+ self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
11
+ self.relu = nn.ReLU()
12
+ self.dropout = nn.Dropout(0.5)
13
+ self._initialize_fc(num_classes)
14
+
15
+ def _initialize_fc(self, num_classes):
16
+ dummy_input = torch.zeros(1, 3, 448, 448)
17
+ x = self.pool(self.relu(self.conv1(dummy_input)))
18
+ x = self.pool(self.relu(self.conv2(x)))
19
+ x = self.pool(self.relu(self.conv3(x)))
20
+ x = x.view(x.size(0), -1)
21
+ flattened_size = x.shape[1]
22
+ self.fc1 = nn.Linear(flattened_size, 512)
23
+ self.fc2 = nn.Linear(512, num_classes)
24
+
25
+ def forward(self, x):
26
+ x = self.pool(self.relu(self.conv1(x)))
27
+ x = self.pool(self.relu(self.conv2(x)))
28
+ x = self.pool(self.relu(self.conv3(x)))
29
+ x = x.view(x.size(0), -1)
30
+ x = self.dropout(self.relu(self.fc1(x)))
31
+ x = self.fc2(x)
32
+ return x
33
+
34
+ def load_model(weights_path, device='cpu'):
35
+ model = SimpleCNN(num_classes=6).to(device)
36
+ state = torch.load(weights_path, map_location=device)
37
+ model.load_state_dict(state)
38
+ model.eval()
39
+ return model