chengan98 commited on
Commit
9bc6867
·
verified ·
1 Parent(s): 54bc442

Update model_loader.py

Browse files
Files changed (1) hide show
  1. model_loader.py +0 -20
model_loader.py CHANGED
@@ -19,26 +19,6 @@ def build_model(nclasses: int = 2, mode: str = None, segment_model: str = None):
19
 
20
  return net
21
 
22
- def build_SurgFM(nclasses: int = 2, pretrained: bool = True, pretrained_weights = None):
23
-
24
-
25
- #net of ConvNext
26
- net = torchvision.models.convnext_large(weights='DEFAULT')
27
- input_emdim = net.classifier[2].in_features
28
- net.classifier[2] = nn.Identity()
29
-
30
- if os.path.isfile(pretrained_weights):
31
- state_dict = torch.load(pretrained_weights, map_location="cpu")
32
- state_dict = state_dict['teacher']
33
-
34
- # remove `backbone.` prefix induced by multicrop wrapper
35
- state_dict = {k.replace("backbone.", ""): v for k, v in state_dict.items() if k.startswith('backbone.')}
36
- msg = net.load_state_dict(state_dict, strict=False)
37
- print(msg, input_emdim)
38
-
39
- net.cuda()
40
-
41
- return net
42
 
43
 
44
  net = build_model(nclasses=num_classes, mode='classify')
 
19
 
20
  return net
21
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
22
 
23
 
24
  net = build_model(nclasses=num_classes, mode='classify')