import os
from typing import Optional

import torch
import torchvision
from ultralytics import YOLO


def build_model(nclasses: int = 2, mode: Optional[str] = None,
                segment_model: Optional[str] = None):
    """
    Build the network for the requested processing mode.

    @param[in]  nclasses       Number of output classes of the ResNet-18
                               classifier (only used when mode == 'classify').
    @param[in]  mode           'classify' for frame classification, or 'mask'
                               for the uninformative-part mask model.
    @param[in]  segment_model  Passed unchanged to the YOLO constructor (only
                               used when mode == 'mask'); presumably a weights
                               path or model name -- confirm with callers.
    @return  For 'classify': a torchvision ResNet-18 already moved to the GPU.
             For 'mask': the YOLO object built from segment_model.
    @raises  ValueError if mode is neither 'classify' nor 'mask'.
    """
    if mode == 'classify':
        # ResNet-18 with a final layer sized for `nclasses`.
        net = torchvision.models.resnet18(num_classes=nclasses)
        net.cuda()  # requires a CUDA-capable device
        return net

    if mode == 'mask':
        return YOLO(segment_model)

    # Previously this path fell through to `return net` with `net` unbound and
    # surfaced as an opaque UnboundLocalError; fail with a clear message.
    raise ValueError(
        f"Unsupported mode {mode!r}; expected 'classify' or 'mask'."
    )


# NOTE(review): `num_classes` is not defined anywhere in the visible file; it
# must be set before this line (and presumably has to match the class count
# the checkpoint was trained with -- confirm).
net = build_model(nclasses=num_classes, mode='classify')

# NOTE(review): this looks like a placeholder -- replace it with the actual
# path of the classification checkpoint file.
model_path = 'Video storyboard classification models'

# Enable multi-GPU support
net = torch.nn.DataParallel(net)
# Let cuDNN benchmark and pick convolution algorithms for the input sizes seen.
torch.backends.cudnn.benchmark = True

# The checkpoint is expected to be a dict whose 'net' entry holds the weights.
# They are loaded into the DataParallel wrapper, so the checkpoint was
# presumably saved from a DataParallel-wrapped model as well -- confirm.
state = torch.load(model_path, map_location=torch.device('cuda'))
net.load_state_dict(state['net'])
net.eval()  # switch to inference mode