""" Builds Pytorch model """ import torch import torchvision.models from torch import nn class ResNet101(nn.Module): """ ResNet101 model specified for the binary problem. The according transforms were taken from pytorch.org. """ def __init__(self): super().__init__() self.weights = torchvision.models.ResNet101_Weights.DEFAULT self.transforms = self.weights.transforms self.resnet = torchvision.models.resnet101(weights=self.weights) for param in self.resnet.parameters(): param.requires_grad = False self.resnet.fc = nn.Linear(in_features=2048, out_features=1) def forward(self, x: torch.Tensor) -> torch.Tensor: x = self.resnet(x) return x