English
File size: 468 Bytes
ede298f
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
from transformers import UperNetForSemanticSegmentation
import torch


def load_model(pretrained_model: str, num_classes: int, device: torch.device) -> torch.nn.Module:
    """Build a UperNet segmentation model sized for ``num_classes`` and place it on ``device``.

    Args:
        pretrained_model: Checkpoint identifier handed straight to
            ``UperNetForSemanticSegmentation.from_pretrained``.
        num_classes: Number of segmentation labels; forwarded as ``num_labels``.
        device: Target device passed to ``Module.to``.

    Returns:
        The loaded model, after ``.to(device)`` has been applied.
    """
    # ignore_mismatched_sizes=True lets a checkpoint whose head was trained for a
    # different label count load without raising on the size mismatch.
    segmentation_model = UperNetForSemanticSegmentation.from_pretrained(
        pretrained_model,
        num_labels=num_classes,
        ignore_mismatched_sizes=True,
    )
    placed_model = segmentation_model.to(device)
    return placed_model