eyupipler commited on
Commit
49c26f6
·
verified ·
1 Parent(s): 7a03700

Update model.py

Browse files
Files changed (1) hide show
  1. model.py +4 -8
model.py CHANGED
@@ -3,10 +3,9 @@ import torch.nn as nn
3
  from huggingface_hub import hf_hub_download
4
 
5
  class SimpleCNN(nn.Module):
6
- def __init__(self, model_type='f', num_classes=4):
7
  super(SimpleCNN, self).__init__()
8
  self.num_classes = num_classes
9
- self.model_type = model_type
10
  if model_type == 'f':
11
  self.conv1 = nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1)
12
  self.conv2 = nn.Conv2d(16, 32, kernel_size=3, stride=1, padding=1)
@@ -26,14 +25,11 @@ class SimpleCNN(nn.Module):
26
  self.conv4 = nn.Conv2d(256, 512, kernel_size=3, stride=1, padding=1)
27
  self.fc1 = nn.Linear(512 * 14 * 14, 1024)
28
  self.dropout = nn.Dropout(0.3)
29
- else:
30
- raise ValueError(f"Unknown model type: {model_type}")
31
-
32
- self.relu = nn.ReLU()
33
- self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
34
  self.fc2 = nn.Linear(self.fc1.out_features, num_classes)
 
 
35
 
36
- def forward(self, x):
37
  x = self.pool(self.relu(self.conv1(x)))
38
  x = self.pool(self.relu(self.conv2(x)))
39
  x = self.pool(self.relu(self.conv3(x)))
 
3
  from huggingface_hub import hf_hub_download
4
 
5
  class SimpleCNN(nn.Module):
6
+ def __init__(self, model_type='c', num_classes=4):
7
  super(SimpleCNN, self).__init__()
8
  self.num_classes = num_classes
 
9
  if model_type == 'f':
10
  self.conv1 = nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1)
11
  self.conv2 = nn.Conv2d(16, 32, kernel_size=3, stride=1, padding=1)
 
25
  self.conv4 = nn.Conv2d(256, 512, kernel_size=3, stride=1, padding=1)
26
  self.fc1 = nn.Linear(512 * 14 * 14, 1024)
27
  self.dropout = nn.Dropout(0.3)
 
 
 
 
 
28
  self.fc2 = nn.Linear(self.fc1.out_features, num_classes)
29
+ self.relu = nn.ReLU()
30
+ self.pool = nn.MaxPool2d(kernel_size=2, stride=2, padding=0)
31
 
32
+ def forward(self, x):
33
  x = self.pool(self.relu(self.conv1(x)))
34
  x = self.pool(self.relu(self.conv2(x)))
35
  x = self.pool(self.relu(self.conv3(x)))