# mmesa-gpu-gitex / app / model_architectures.py
# (author: vitorcalvi, commit cded863)
import torch
import torch.nn as nn
import torchvision.models as models
class ResNet50(nn.Module):
    """ResNet-50 backbone with a configurable input channel count and class head.

    Wraps ``torchvision.models.resnet50`` loaded with ImageNet weights,
    optionally swapping the first conv for non-RGB input and always replacing
    the final fully connected layer to emit ``num_classes`` logits.
    """

    def __init__(self, num_classes=7, channels=3):
        super(ResNet50, self).__init__()
        # NOTE(review): `pretrained=True` is deprecated in newer torchvision
        # releases in favor of the `weights=` argument — confirm the pinned
        # torchvision version before migrating.
        self.resnet = models.resnet50(pretrained=True)

        # Non-RGB input: swap the stem conv (pretrained conv1 weights are
        # discarded in this case and the new conv trains from scratch).
        if channels != 3:
            self.resnet.conv1 = nn.Conv2d(
                channels, 64, kernel_size=7, stride=2, padding=3, bias=False
            )

        # Replace the ImageNet classifier head with one sized for our classes.
        head_in = self.resnet.fc.in_features
        self.resnet.fc = nn.Linear(head_in, num_classes)

    def forward(self, x):
        """Run the full network and return class logits."""
        return self.resnet(x)

    def extract_features(self, x):
        """Return the pooled, flattened feature vector (pre-classifier).

        Runs every stage of the backbone except the final fc layer.
        """
        stages = (
            self.resnet.conv1,
            self.resnet.bn1,
            self.resnet.relu,
            self.resnet.maxpool,
            self.resnet.layer1,
            self.resnet.layer2,
            self.resnet.layer3,
            self.resnet.layer4,
            self.resnet.avgpool,
        )
        out = x
        for stage in stages:
            out = stage(out)
        return torch.flatten(out, 1)
class LSTMPyTorch(nn.Module):
    """Many-to-one LSTM classifier.

    Consumes a batch-first sequence tensor of shape
    (batch, seq_len, input_size) and classifies it from the hidden state of
    the final time step via a single linear layer.
    """

    def __init__(self, input_size, hidden_size, num_layers, num_classes):
        super(LSTMPyTorch, self).__init__()
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.lstm = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True)
        self.fc = nn.Linear(hidden_size, num_classes)

    def forward(self, x):
        """Return class logits of shape (batch, num_classes)."""
        # Zero-initialized hidden and cell states, one per stacked layer.
        state_shape = (self.num_layers, x.size(0), self.hidden_size)
        initial_hidden = torch.zeros(state_shape, device=x.device)
        initial_cell = torch.zeros(state_shape, device=x.device)

        # seq_out: (batch, seq_len, hidden_size); final states are discarded.
        seq_out, _ = self.lstm(x, (initial_hidden, initial_cell))

        # Classify using only the representation at the last time step.
        last_step = seq_out[:, -1, :]
        return self.fc(last_step)