from transformers import UperNetForSemanticSegmentation | |
import torch | |
def load_model(pretrained_model: str, num_classes: int, device: torch.device) -> torch.nn.Module:
    """Load a pretrained UperNet segmentation model with a resized label head.

    Args:
        pretrained_model: Checkpoint identifier or path forwarded unchanged to
            ``UperNetForSemanticSegmentation.from_pretrained``.
        num_classes: Number of segmentation labels; passed as ``num_labels``.
        device: Target device the loaded model is moved to.

    Returns:
        The loaded model, after ``.to(device)`` has been applied.
    """
    # The checkpoint's head may have a different number of labels than
    # ``num_classes``. ``ignore_mismatched_sizes`` lets loading proceed instead
    # of raising on those shape mismatches; per the transformers docs, the
    # mismatched weights are freshly initialized.
    loader_options = {
        "num_labels": num_classes,
        "ignore_mismatched_sizes": True,
    }
    segmenter = UperNetForSemanticSegmentation.from_pretrained(
        pretrained_model, **loader_options
    )
    segmenter = segmenter.to(device)
    return segmenter