File size: 813 Bytes
54bc442
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
import os
import torch
import torchvision
from ultralytics import YOLO



def build_model(nclasses: int = 2, mode: str = None, segment_model: str = None):
    """
    Build the network for the requested task.

    @param[in]  nclasses       number of output classes of the ResNet-18 head
                               (only used when mode == 'classify')
    @param[in]  mode           'classify' -> ResNet-18 frame classifier, moved to the GPU;
                               'mask'     -> ultralytics YOLO model for the
                                             uninformative-part mask
    @param[in]  segment_model  model/weights argument forwarded to YOLO
                               (only used when mode == 'mask')
    @return     the constructed network
    @throws     ValueError if mode is neither 'classify' nor 'mask'
    """
    if mode == 'classify':
        # ResNet-18 built with torchvision's default weights=None (random init) and a
        # head of `nclasses` outputs; .cuda() requires a CUDA-capable device.
        net = torchvision.models.resnet18(num_classes=nclasses)
        net.cuda()
    elif mode == 'mask':
        net = YOLO(segment_model)
    else:
        # Without this branch an unknown/None mode left `net` unbound and the
        # function died with an unhelpful UnboundLocalError at the return.
        raise ValueError(
            f"unknown mode {mode!r}; expected 'classify' or 'mask'"
        )

    return net



# Script section: build the ResNet-18 frame classifier and load its trained weights
# for inference.
# NOTE(review): this value looks like a placeholder/description rather than a
# checkpoint file path — confirm it points at the saved checkpoint before running.
model_path = 'Video storyboard classification models'

torch.backends.cudnn.benchmark = True
# NOTE(review): depending on the torch version / weights_only default, torch.load may
# unpickle arbitrary objects — only load checkpoints from a trusted source.
state = torch.load(model_path, map_location=torch.device('cuda'))

# `num_classes` was referenced here but never defined in this file (NameError at
# import). Read it from the checkpoint instead: the strict load_state_dict below into
# DataParallel(resnet18) requires a 'module.fc.weight' tensor of shape
# (num_classes, in_features).
num_classes = state['net']['module.fc.weight'].shape[0]

net = build_model(nclasses=num_classes, mode='classify')

# Enable multi-GPU support; DataParallel prefixes parameter names with 'module.',
# which is why the checkpoint keys above carry that prefix.
net = torch.nn.DataParallel(net)
net.load_state_dict(state['net'])
net.eval()