|
import os |
|
import torch |
|
import torchvision |
|
from ultralytics import YOLO |
|
|
|
|
|
|
|
def build_model(nclasses: int = 2, mode: str = None, segment_model: str = None, *, device='cuda'):
    """Build the network for frame classification or uninformative-part masking.

    Args:
        nclasses: Number of output classes for the ResNet-18 classifier.
            Only used when ``mode == 'classify'``.
        mode: ``'classify'`` builds a randomly initialised torchvision
            ResNet-18 with ``num_classes=nclasses`` and moves it to ``device``.
            ``'mask'`` constructs an Ultralytics ``YOLO`` model from
            ``segment_model``.
        segment_model: Weights path/name passed straight to ``YOLO(...)``.
            Only used when ``mode == 'mask'``.
        device: Target device for the classifier (keyword-only). Defaults to
            ``'cuda'``, matching the previous unconditional ``.cuda()`` call.

    Returns:
        The ResNet-18 ``torch.nn.Module`` for ``'classify'``, or the ``YOLO``
        object for ``'mask'``.

    Raises:
        ValueError: If ``mode`` is not ``'classify'`` or ``'mask'``.
    """
    if mode == 'classify':
        net = torchvision.models.resnet18(num_classes=nclasses)
        # Module.to() moves parameters/buffers in place. 'cuda' with no index
        # resolves to the current CUDA device, the same target as .cuda().
        net.to(device)
        return net

    if mode == 'mask':
        # No explicit device move here; placement is left to Ultralytics.
        return YOLO(segment_model)

    # Previously an unknown mode (including the default None) fell through to
    # `return net` and raised an opaque UnboundLocalError. Fail fast instead.
    raise ValueError(f"Unknown mode {mode!r}; expected 'classify' or 'mask'.")
|
|
|
|
|
|
|
# --- Inference setup for the frame classifier ---------------------------------
# Build the ResNet-18 classifier and wrap it in DataParallel in one step; only
# the wrapped module is kept under the global name `net`.
# NOTE(review): `num_classes` is not defined anywhere in the visible file —
# confirm it is bound before this point, otherwise this line raises NameError.
net = torch.nn.DataParallel(build_model(nclasses=num_classes, mode='classify'))

# NOTE(review): this reads like a placeholder/folder name rather than a
# checkpoint file path — confirm before running.
model_path = 'Video storyboard classification models'

# Let cuDNN auto-tune convolution algorithms for the input sizes it sees.
torch.backends.cudnn.benchmark = True

# Weights are loaded into the DataParallel wrapper, so the checkpoint's keys
# must carry the "module." prefix (presumably it was saved from a wrapped
# model — verify).
# NOTE(review): torch.load unpickles arbitrary objects; only load trusted files.
state = torch.load(model_path, map_location=torch.device('cuda'))
net.load_state_dict(state['net'])

# Switch to inference mode (affects BatchNorm/Dropout behaviour).
net.eval()
|
|