# crop_health_monitor/model/architecture.py
# Last commit: 49a5910 "Create model/architecture.py" by CZerion.
import torch
import torch.nn as nn
from transformers import AutoModel
class DiseaseModel(nn.Module):
    """Disease classifier: a pretrained Hugging Face backbone plus a linear head.

    The backbone is loaded with ``AutoModel.from_pretrained(backbone_id)``.
    Its ``last_hidden_state`` is reduced to one feature vector per image, and
    a single ``nn.Linear`` maps that vector to one logit per disease class.

    Supported ``last_hidden_state`` layouts:

    * 3-D ``(batch, tokens, dim)``: the feature at token index 0 is used.
      This is the CLS token for ViT-style backbones. For backbones without a
      CLS token it is just the first patch token, so verify this suits your
      checkpoint.
    * 4-D ``(batch, channels, H, W)``: the convolutional feature map is
      global-average-pooled over ``H`` and ``W``.

    Args:
        backbone_id: Model id or local path passed to
            ``AutoModel.from_pretrained``.
        num_disease_classes: Number of output logits produced by the head.

    Raises:
        ValueError: If the backbone config exposes neither ``hidden_size``
            nor a non-empty ``hidden_sizes``.
    """

    def __init__(self, backbone_id: str, num_disease_classes: int) -> None:
        super().__init__()
        self.backbone = AutoModel.from_pretrained(backbone_id)
        feature_dim = self._feature_dim(self.backbone.config)
        self.head = nn.Linear(feature_dim, num_disease_classes)

    @staticmethod
    def _feature_dim(config) -> int:
        """Return the width of the backbone's final features, read from its config."""
        # Transformer-style configs expose a single ``hidden_size``.
        hidden_size = getattr(config, "hidden_size", None)
        if hidden_size is not None:
            return hidden_size
        # Multi-stage CNN configs (e.g. ResNet/ConvNeXt in transformers) list
        # per-stage widths instead; the last stage produces last_hidden_state.
        stage_sizes = getattr(config, "hidden_sizes", None)
        if stage_sizes:
            return stage_sizes[-1]
        raise ValueError(
            f"cannot infer feature size from {type(config).__name__}: "
            "expected a 'hidden_size' or 'hidden_sizes' attribute"
        )

    @staticmethod
    def _pool(hidden_state: torch.Tensor) -> torch.Tensor:
        """Reduce ``last_hidden_state`` to a ``(batch, dim)`` feature tensor.

        Raises:
            ValueError: If ``hidden_state`` is neither 3-D nor 4-D.
        """
        if hidden_state.dim() == 3:
            # Token sequence: keep token 0, as the original model did.
            return hidden_state[:, 0]
        if hidden_state.dim() == 4:
            # Feature map: indexing [:, 0] here would select channel 0,
            # so average over the spatial dims instead.
            return hidden_state.mean(dim=(2, 3))
        raise ValueError(
            "expected last_hidden_state with 3 (batch, tokens, dim) or "
            "4 (batch, channels, H, W) dimensions, "
            f"got shape {tuple(hidden_state.shape)}"
        )

    def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
        """Compute disease logits for a batch of images.

        Args:
            pixel_values: Image batch passed to the backbone as
                ``pixel_values=``. Presumably ``(batch, 3, H, W)`` from the
                matching image processor (confirm against the caller).

        Returns:
            Logits of shape ``(batch, num_disease_classes)``.
        """
        outputs = self.backbone(pixel_values=pixel_values)
        features = self._pool(outputs.last_hidden_state)
        return self.head(features)