vitorcalvi commited on
Commit
cded863
Β·
1 Parent(s): 662fda2
Files changed (1) hide show
  1. app/model_architectures.py +12 -3
app/model_architectures.py CHANGED
@@ -6,9 +6,12 @@ class ResNet50(nn.Module):
6
  def __init__(self, num_classes=7, channels=3):
7
  super(ResNet50, self).__init__()
8
  self.resnet = models.resnet50(pretrained=True)
 
9
  # Modify the first convolutional layer if channels != 3
10
  if channels != 3:
11
  self.resnet.conv1 = nn.Conv2d(channels, 64, kernel_size=7, stride=2, padding=3, bias=False)
 
 
12
  num_features = self.resnet.fc.in_features
13
  self.resnet.fc = nn.Linear(num_features, num_classes)
14
 
@@ -16,6 +19,7 @@ class ResNet50(nn.Module):
16
  return self.resnet(x)
17
 
18
  def extract_features(self, x):
 
19
  x = self.resnet.conv1(x)
20
  x = self.resnet.bn1(x)
21
  x = self.resnet.relu(x)
@@ -39,8 +43,13 @@ class LSTMPyTorch(nn.Module):
39
  self.fc = nn.Linear(hidden_size, num_classes)
40
 
41
  def forward(self, x):
 
42
  h0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size).to(x.device)
43
  c0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size).to(x.device)
44
- out, _ = self.lstm(x, (h0, c0))
45
- out = self.fc(out[:, -1, :])
46
- return out
 
 
 
 
 
6
  def __init__(self, num_classes=7, channels=3):
7
  super(ResNet50, self).__init__()
8
  self.resnet = models.resnet50(pretrained=True)
9
+
10
  # Modify the first convolutional layer if channels != 3
11
  if channels != 3:
12
  self.resnet.conv1 = nn.Conv2d(channels, 64, kernel_size=7, stride=2, padding=3, bias=False)
13
+
14
+ # Replace the fully connected layer
15
  num_features = self.resnet.fc.in_features
16
  self.resnet.fc = nn.Linear(num_features, num_classes)
17
 
 
19
  return self.resnet(x)
20
 
21
  def extract_features(self, x):
22
+ # Feature extraction using layers up to layer4
23
  x = self.resnet.conv1(x)
24
  x = self.resnet.bn1(x)
25
  x = self.resnet.relu(x)
 
43
  self.fc = nn.Linear(hidden_size, num_classes)
44
 
45
  def forward(self, x):
46
+ # Initialize hidden state and cell state with zeros
47
  h0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size).to(x.device)
48
  c0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size).to(x.device)
49
+
50
+ # Forward propagate LSTM
51
+ out, _ = self.lstm(x, (h0, c0)) # out: tensor of shape (batch_size, seq_length, hidden_size)
52
+
53
+ # Pass through the fully connected layer
54
+ out = self.fc(out[:, -1, :]) # Only the last time step output
55
+ return out