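# Residue from the Hugging Face Hub page this file was copied from (page
# language: English), kept as a comment so the module parses.
# Uploaded by Antoine1091 — commit "Upload folder using huggingface_hub",
# revision ede298f (verified).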
from transformers import UperNetForSemanticSegmentation
import torch
def load_model(pretrained_model: str, num_classes: int, device: torch.device) -> torch.nn.Module:
    """Load a UperNet semantic-segmentation model sized for ``num_classes`` labels.

    The checkpoint is loaded with
    ``UperNetForSemanticSegmentation.from_pretrained``, overriding its label
    count with ``num_classes``, and the resulting model is moved to ``device``.

    Args:
        pretrained_model: Checkpoint name or path, forwarded as the first
            argument of ``from_pretrained``.
        num_classes: Number of segmentation labels, passed as ``num_labels``.
        device: Device the loaded model is moved to.

    Returns:
        The loaded model, placed on ``device``.
    """
    load_kwargs = {
        "num_labels": num_classes,
        # If the checkpoint's weights do not match the requested label count,
        # transformers skips those weights (leaving them newly initialised)
        # instead of raising an error.
        "ignore_mismatched_sizes": True,
    }
    segmenter = UperNetForSemanticSegmentation.from_pretrained(pretrained_model, **load_kwargs)
    segmenter = segmenter.to(device)
    return segmenter