Spaces:
Runtime error
Runtime error
Commit
Β·
cded863
1
Parent(s):
662fda2
- app/model_architectures.py +12 -3
app/model_architectures.py
CHANGED
|
@@ -6,9 +6,12 @@ class ResNet50(nn.Module):
|
|
| 6 |
def __init__(self, num_classes=7, channels=3):
|
| 7 |
super(ResNet50, self).__init__()
|
| 8 |
self.resnet = models.resnet50(pretrained=True)
|
|
|
|
| 9 |
# Modify the first convolutional layer if channels != 3
|
| 10 |
if channels != 3:
|
| 11 |
self.resnet.conv1 = nn.Conv2d(channels, 64, kernel_size=7, stride=2, padding=3, bias=False)
|
|
|
|
|
|
|
| 12 |
num_features = self.resnet.fc.in_features
|
| 13 |
self.resnet.fc = nn.Linear(num_features, num_classes)
|
| 14 |
|
|
@@ -16,6 +19,7 @@ class ResNet50(nn.Module):
|
|
| 16 |
return self.resnet(x)
|
| 17 |
|
| 18 |
def extract_features(self, x):
|
|
|
|
| 19 |
x = self.resnet.conv1(x)
|
| 20 |
x = self.resnet.bn1(x)
|
| 21 |
x = self.resnet.relu(x)
|
|
@@ -39,8 +43,13 @@ class LSTMPyTorch(nn.Module):
|
|
| 39 |
self.fc = nn.Linear(hidden_size, num_classes)
|
| 40 |
|
| 41 |
def forward(self, x):
|
|
|
|
| 42 |
h0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size).to(x.device)
|
| 43 |
c0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size).to(x.device)
|
| 44 |
-
|
| 45 |
-
|
| 46 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 6 |
def __init__(self, num_classes=7, channels=3):
|
| 7 |
super(ResNet50, self).__init__()
|
| 8 |
self.resnet = models.resnet50(pretrained=True)
|
| 9 |
+
|
| 10 |
# Modify the first convolutional layer if channels != 3
|
| 11 |
if channels != 3:
|
| 12 |
self.resnet.conv1 = nn.Conv2d(channels, 64, kernel_size=7, stride=2, padding=3, bias=False)
|
| 13 |
+
|
| 14 |
+
# Replace the fully connected layer
|
| 15 |
num_features = self.resnet.fc.in_features
|
| 16 |
self.resnet.fc = nn.Linear(num_features, num_classes)
|
| 17 |
|
|
|
|
| 19 |
return self.resnet(x)
|
| 20 |
|
| 21 |
def extract_features(self, x):
|
| 22 |
+
# Feature extraction using layers up to layer4
|
| 23 |
x = self.resnet.conv1(x)
|
| 24 |
x = self.resnet.bn1(x)
|
| 25 |
x = self.resnet.relu(x)
|
|
|
|
| 43 |
self.fc = nn.Linear(hidden_size, num_classes)
|
| 44 |
|
| 45 |
def forward(self, x):
|
| 46 |
+
# Initialize hidden state and cell state with zeros
|
| 47 |
h0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size).to(x.device)
|
| 48 |
c0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size).to(x.device)
|
| 49 |
+
|
| 50 |
+
# Forward propagate LSTM
|
| 51 |
+
out, _ = self.lstm(x, (h0, c0)) # out: tensor of shape (batch_size, seq_length, hidden_size)
|
| 52 |
+
|
| 53 |
+
# Pass through the fully connected layer
|
| 54 |
+
out = self.fc(out[:, -1, :]) # Only the last time step output
|
| 55 |
+
return out
|