File size: 1,383 Bytes
733fcd8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
from typing import List, Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F
from PIL import Image
from torchvision import models, transforms

from config import CLASS_LABELS, MODEL_PATH


def get_model():
    """Build the default EfficientNet-B0 classifier from ``MODEL_PATH``.

    The classification head is sized to ``len(CLASS_LABELS)``. The checkpoint
    tensors are mapped to the CPU, and the model is returned in eval mode.
    """
    # Previously this requested the ImageNet-pretrained weights, which forces
    # a download. load_state_dict (strict by default) then overwrote every
    # parameter and buffer anyway. The shared loader builds the network
    # without pretrained weights, so the result is identical without the
    # wasted network/disk I/O.
    return get_model_by_name(MODEL_PATH, len(CLASS_LABELS))

def get_model_by_name(model_path: str, num_classes: int):
    """Return an EfficientNet-B0 in eval mode with weights read from *model_path*.

    Args:
        model_path: Path to a saved ``state_dict`` for this architecture. Its
            tensors are mapped to the CPU when loaded.
        num_classes: Output size of the replacement classification head. It
            must agree with the checkpoint, or ``load_state_dict`` rejects it.
    """
    net = models.efficientnet_b0(weights=None)
    # Swap the layer at classifier[1] for a Linear layer sized for
    # num_classes, keeping the same input width.
    head_in = net.classifier[1].in_features
    net.classifier[1] = nn.Linear(head_in, num_classes)
    checkpoint = torch.load(model_path, map_location='cpu')
    net.load_state_dict(checkpoint)
    net.eval()
    return net


def predict(
    image: Image.Image,
    model,
    class_labels: Optional[List[str]] = None,
) -> Tuple[str, float]:
    """Classify a single PIL image and return ``(label, confidence)``.

    Args:
        image: Input image. It is converted to RGB first, so grayscale,
            palette and RGBA images are accepted as well as RGB.
        model: Classifier called on a ``(1, 3, 224, 224)`` float tensor. Its
            output is treated as ``(1, num_classes)`` logits.
        class_labels: Labels indexed by class id. Defaults to ``CLASS_LABELS``.

    Returns:
        The predicted label and its softmax probability as a Python float.
    """
    # Resize to a fixed 224x224 input; ToTensor scales pixel values to [0, 1].
    # NOTE(review): no mean/std normalization is applied here. This must
    # match the transform used when the checkpoint was trained -- confirm.
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
    ])
    # EfficientNet-B0 expects 3-channel input. Without this conversion, RGBA,
    # L or P images produce 4- or 1-channel tensors and fail in the first conv.
    image_tensor = transform(image.convert("RGB")).unsqueeze(0)

    with torch.no_grad():
        logits = model(image_tensor)
        probabilities = F.softmax(logits, dim=1)
        confidence, pred = torch.max(probabilities, dim=1)

    if class_labels is None:
        class_labels = CLASS_LABELS

    return class_labels[pred.item()], confidence.item()