chengan committed
Commit 54bc442 · 1 Parent(s): 842298a

code update

Files changed (1)
  1. model_loader.py +52 -0
model_loader.py ADDED
@@ -0,0 +1,52 @@
+ import os
+ import torch
+ import torch.nn as nn
+ import torchvision
+ from ultralytics import YOLO
+
+
+ def build_model(nclasses: int = 2, mode: str = None, segment_model: str = None):
+     """Build a network for video storyboard processing.
+
+     @param[in] nclasses       number of output classes for the frame classifier
+     @param[in] mode           'classify' for frame classification, 'mask' for
+                               uninformative-part segmentation
+     @param[in] segment_model  path to the YOLO segmentation weights, used when mode == 'mask'
+     """
+     if mode == 'classify':
+         # ResNet-18 with a classification head of nclasses outputs
+         net = torchvision.models.resnet18(num_classes=nclasses)
+         net.cuda()
+     elif mode == 'mask':
+         # YOLO segmentation model for masking uninformative parts
+         net = YOLO(segment_model)
+     else:
+         raise ValueError(f"Unknown mode: {mode!r}")
+
+     return net
+
+
+ def build_SurgFM(nclasses: int = 2, pretrained: bool = True, pretrained_weights=None):
+     # ConvNeXt-Large backbone; the classification head is replaced with an
+     # identity so the network returns backbone features
+     net = torchvision.models.convnext_large(weights='DEFAULT')
+     input_emdim = net.classifier[2].in_features
+     net.classifier[2] = nn.Identity()
+
+     if pretrained and pretrained_weights and os.path.isfile(pretrained_weights):
+         state_dict = torch.load(pretrained_weights, map_location="cpu")
+         state_dict = state_dict['teacher']
+
+         # remove the `backbone.` prefix induced by the multicrop wrapper
+         state_dict = {k.replace("backbone.", ""): v for k, v in state_dict.items() if k.startswith('backbone.')}
+         msg = net.load_state_dict(state_dict, strict=False)
+         print(msg, input_emdim)
+
+     net.cuda()
+
+     return net
+
+
+ if __name__ == "__main__":
+     num_classes = 2  # matches the build_model default
+     net = build_model(nclasses=num_classes, mode='classify')
+     model_path = 'Video storyboard classification models'
+
+     # Enable multi-GPU support
+     net = torch.nn.DataParallel(net)
+     torch.backends.cudnn.benchmark = True
+     state = torch.load(model_path, map_location=torch.device('cuda'))
+     net.load_state_dict(state['net'])
+     net.eval()
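
For reference, a minimal usage sketch of build_SurgFM as a frozen feature extractor. It assumes a CUDA device, that model_loader.py is importable, and a local checkpoint file; 'surgfm_checkpoint.pth' is a hypothetical name, not part of this commit.

import torch
from model_loader import build_SurgFM

# Build the ConvNeXt-Large feature extractor and (optionally) load
# DINO-style teacher weights. The checkpoint path is a placeholder.
net = build_SurgFM(pretrained=True, pretrained_weights='surgfm_checkpoint.pth')
net.eval()

with torch.no_grad():
    frames = torch.randn(4, 3, 224, 224).cuda()  # dummy batch of RGB frames
    feats = net(frames)                          # ConvNeXt-Large features, shape (4, 1536)
print(feats.shape)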