DHEIVER commited on
Commit
d7abc8f
·
1 Parent(s): 774e216

Create model.py

Browse files
Files changed (1) hide show
  1. model.py +27 -0
model.py ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Builds Pytorch model
3
+ """
4
+ import torch
5
+ import torchvision.models
6
+ from torch import nn
7
+
8
+
9
+ class ResNet101(nn.Module):
10
+ """
11
+ ResNet101 model specified for the binary problem. The according transforms were taken from pytorch.org.
12
+ """
13
+
14
+ def __init__(self):
15
+ super().__init__()
16
+ self.weights = torchvision.models.ResNet101_Weights.DEFAULT
17
+ self.transforms = self.weights.transforms
18
+ self.resnet = torchvision.models.resnet101(weights=self.weights)
19
+
20
+ for param in self.resnet.parameters():
21
+ param.requires_grad = False
22
+
23
+ self.resnet.fc = nn.Linear(in_features=2048, out_features=1)
24
+
25
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
26
+ x = self.resnet(x)
27
+ return x