File size: 494 Bytes
49a5910
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
import torch.nn as nn
from transformers import AutoModel

class DiseaseModel(nn.Module):
    """Classifier made of a pretrained ``transformers`` backbone and a linear head.

    The backbone is loaded with ``AutoModel.from_pretrained``. The element at
    index 0 along dimension 1 of its ``last_hidden_state`` is passed through a
    single ``nn.Linear`` layer. That layer's output is returned as-is, with no
    activation or normalization applied.
    """

    def __init__(self, backbone_id, num_disease_classes):
        """Load the backbone and build the classification head.

        Args:
            backbone_id: Model name or path forwarded to
                ``AutoModel.from_pretrained``.
            num_disease_classes: Number of output features of the linear head.
        """
        super().__init__()
        # Register the backbone before the head: attribute names and order
        # define the state_dict keys and parameters() ordering.
        self.backbone = AutoModel.from_pretrained(backbone_id)
        # NOTE(review): assumes the backbone config exposes a scalar
        # ``hidden_size`` matching the width of last_hidden_state — confirm
        # for the chosen backbone.
        feature_dim = self.backbone.config.hidden_size
        self.head = nn.Linear(feature_dim, num_disease_classes)

    def forward(self, pixel_values):
        """Run the backbone on ``pixel_values`` and score the first token.

        Args:
            pixel_values: Image tensor passed to the backbone as the
                ``pixel_values`` keyword argument.

        Returns:
            The output of ``self.head`` applied to
            ``last_hidden_state[:, 0]``.
        """
        backbone_out = self.backbone(pixel_values=pixel_values)
        token_states = backbone_out.last_hidden_state
        # Presumably token_states is (batch, seq, hidden) and position 0 is a
        # [CLS]-style summary token, as in ViT-like models — TODO confirm for
        # backbones whose outputs are laid out differently.
        first_token = token_states[:, 0]
        return self.head(first_token)