Spaces:
Runtime error
Runtime error
import torch | |
import torch.nn as nn | |
import torchvision.models as models | |
import torchvision.transforms as transforms | |
from PIL import Image | |
import time | |
class AE(nn.Module): | |
def __init__(self): | |
super().__init__() | |
self.encoder = nn.Sequential( | |
nn.Linear(2048, 512), nn.ReLU(), | |
nn.Linear(512, 128) | |
) | |
self.decoder = nn.Sequential( | |
nn.Linear(128, 512), nn.ReLU(), | |
nn.Linear(512, 2048) | |
) | |
def forward(self, x): | |
return self.decoder(self.encoder(x)) | |
resnet = models.resnet50(weights=models.ResNet50_Weights.DEFAULT) | |
resnet.fc = nn.Identity() | |
resnet.eval() | |
autoencoder = AE() | |
def model_option(view, category): | |
if view == "crl": | |
if category == "abdomen": | |
autoencoder.load_state_dict(torch.load('models/abdomen_autoencoder-0.0058.pth')) | |
elif category == "body": | |
autoencoder.load_state_dict(torch.load('models/body_autoencoder-0.0060.pth')) | |
elif category == "diencephalon": | |
autoencoder.load_state_dict(torch.load('models/diencephalon_autoencoder-0.0050.pth')) | |
elif category == "gsac": | |
autoencoder.load_state_dict(torch.load('models/gestation_sac_autoencoder-0.0044.pth')) | |
elif category == "head": | |
autoencoder.load_state_dict(torch.load('models/head_autoencoder-0.0077.pth')) | |
elif category == "lv": | |
autoencoder.load_state_dict(torch.load('models/lateral_ventricle_autoencoder-0.0045.pth')) | |
elif category == "mx": | |
autoencoder.load_state_dict(torch.load('models/maxilla_autoencoder-0.0054.pth')) | |
elif category == "mds": | |
autoencoder.load_state_dict(torch.load('models/mds_mandible_autoencoder-0.0039.pth')) | |
elif category == "mls": | |
autoencoder.load_state_dict(torch.load('models/mls_mandible_ventricle_autoencoder-0.0047.pth')) | |
elif category == "nb": | |
autoencoder.load_state_dict(torch.load('models/nasal_bone_autoencoder-0.0026.pth')) | |
elif category == "ntaps": | |
autoencoder.load_state_dict(torch.load('models/ntaps_autoencoder-0.0032.pth')) | |
elif category == "rbp": | |
autoencoder.load_state_dict(torch.load('models/rhombencephalon_autoencoder-0.0044.pth')) | |
elif category == "thorax": | |
autoencoder.load_state_dict(torch.load('models/thorax_autoencoder-0.0058.pth')) | |
#elif view == "nt": | |
autoencoder.eval() | |
transform = transforms.Compose([ | |
transforms.Resize((224, 224)), | |
transforms.ToTensor(), | |
transforms.Normalize([0.485, 0.456, 0.406], | |
[0.229, 0.224, 0.225]) | |
]) | |
def predict(cropped, view, category): | |
model_option(view, category) | |
img = cropped.convert("RGB") | |
img_tensor = transform(img).unsqueeze(0) | |
with torch.no_grad(): | |
feat = resnet(img_tensor).squeeze().numpy() | |
input_tensor = torch.tensor(feat).float().unsqueeze(0) | |
with torch.no_grad(): | |
recon = autoencoder(input_tensor) | |
error = nn.functional.mse_loss(recon, input_tensor).item() | |
return error | |