# Spaces: Running
import torch
import torch.nn as nn
from transformers import AutoModel
class DiseaseModel(nn.Module):
    """Pretrained backbone followed by a single linear classification head.

    The backbone is loaded with ``AutoModel.from_pretrained`` and the head
    maps the backbone's ``config.hidden_size`` features to
    ``num_disease_classes`` outputs.
    """

    def __init__(self, backbone_id: str, num_disease_classes: int) -> None:
        """Load the backbone and build the classification head.

        Args:
            backbone_id: Identifier passed unchanged to
                ``AutoModel.from_pretrained``.
            num_disease_classes: Output width of the linear head.
        """
        super().__init__()
        self.backbone = AutoModel.from_pretrained(backbone_id)
        # The head's input width follows whatever the loaded backbone reports.
        hidden = self.backbone.config.hidden_size
        self.head = nn.Linear(hidden, num_disease_classes)

    def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
        """Run the backbone and classify from the first sequence position.

        Args:
            pixel_values: Input forwarded to the backbone as the
                ``pixel_values`` keyword argument.
                NOTE(review): presumably an image batch shaped
                (batch, channels, height, width) -- confirm against caller.

        Returns:
            The linear head's output for each batch item (raw scores; no
            activation is applied here).
        """
        outputs = self.backbone(pixel_values=pixel_values)
        # Keep only index 0 along the second axis of the hidden states.
        # NOTE(review): for ViT-style backbones this is presumably the
        # [CLS] token -- confirm for the backbone actually used.
        feat = outputs.last_hidden_state[:, 0]
        return self.head(feat)