import torch
from transformers import UperNetForSemanticSegmentation


def load_model(
    pretrained_model: str,
    num_classes: int,
    device: torch.device,
    **from_pretrained_kwargs,
) -> torch.nn.Module:
    """Load a pretrained UperNet model with a custom number of classes.

    The checkpoint is loaded via
    ``UperNetForSemanticSegmentation.from_pretrained`` with
    ``num_labels=num_classes``. The resulting model is then moved to
    ``device``.

    Args:
        pretrained_model: Model id or local path accepted by
            ``from_pretrained``.
        num_classes: Number of segmentation classes for the output head.
            Must be at least 1.
        device: Device to move the loaded model to.
        **from_pretrained_kwargs: Extra keyword arguments forwarded to
            ``from_pretrained`` (e.g. ``id2label``, ``revision``,
            ``cache_dir``). ``ignore_mismatched_sizes`` defaults to
            ``True`` but may be overridden here. Do not pass
            ``num_labels``; it is set from ``num_classes``.

    Returns:
        The model, as returned by ``model.to(device)``.

    Raises:
        ValueError: If ``num_classes`` is less than 1.
    """
    # Fail fast with a clear message rather than letting an invalid head
    # size surface as an obscure error deep inside model construction.
    if num_classes < 1:
        raise ValueError(
            f"num_classes must be a positive integer, got {num_classes!r}"
        )

    # The checkpoint's classification layers are sized for its original
    # label count. Ignoring mismatched sizes lets those weights be freshly
    # initialised for `num_classes` instead of failing to load. Callers may
    # still override this default explicitly.
    from_pretrained_kwargs.setdefault("ignore_mismatched_sizes", True)

    model = UperNetForSemanticSegmentation.from_pretrained(
        pretrained_model,
        num_labels=num_classes,
        **from_pretrained_kwargs,
    )
    return model.to(device)