import torch.nn as nn
from transformers import AutoModel


class DiseaseModel(nn.Module):
    """Classifier made of a pretrained backbone and a single linear head.

    The backbone is loaded with ``AutoModel.from_pretrained`` and the head
    projects one feature vector per sample onto ``num_disease_classes``
    outputs.
    """

    def __init__(self, backbone_id, num_disease_classes):
        """Load the backbone and build the classification head.

        Args:
            backbone_id: Value forwarded unchanged to
                ``AutoModel.from_pretrained``.
            num_disease_classes: Number of output units of the linear head.
        """
        super().__init__()
        self.backbone = AutoModel.from_pretrained(backbone_id)
        # The head's input width is taken from the backbone's own config.
        # NOTE(review): assumes the loaded config exposes ``hidden_size``
        # -- confirm for every backbone family this is used with.
        embed_dim = self.backbone.config.hidden_size
        self.head = nn.Linear(
            in_features=embed_dim,
            out_features=num_disease_classes,
        )

    def forward(self, pixel_values):
        """Run the backbone on ``pixel_values`` and apply the linear head.

        Args:
            pixel_values: Input passed to the backbone as its
                ``pixel_values`` keyword argument.

        Returns:
            The raw output of the linear head; no softmax or sigmoid is
            applied here.
        """
        backbone_out = self.backbone(pixel_values=pixel_values)
        # Keep only index 0 along the second axis of the last hidden state.
        # NOTE(review): presumably this is the class token of a
        # (batch, tokens, hidden) output -- confirm for backbones that have
        # no such token or that emit a spatial feature map.
        pooled = backbone_out.last_hidden_state[:, 0]
        scores = self.head(pooled)
        return scores