File size: 3,398 Bytes
9e35b9e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
import logging

from pathlib import Path

import torch

from fastapi import HTTPException, status
from PIL import Image
from torchvision import models
from typing import Tuple

import src.config as config


# Module-level logger named after this module.
logger = logging.getLogger(__name__)


def _read_class_names(labels_path: Path) -> list:
    """Read the ordered class names stored next to the model checkpoint.

    One class name per line. Surrounding whitespace is stripped, and blank
    lines are skipped so they do not inflate the number of output classes.
    Index ``i`` of the returned list is used as the label for output logit ``i``.

    :param labels_path: path to the ``labels.txt`` file
    :type labels_path: Path
    :return: class names, in file order
    :rtype: list
    """
    with open(labels_path, "r", encoding="utf-8") as labels_fp:
        return [name for name in (line.strip() for line in labels_fp) if name]


def _load_model(model_path, num_classes: int, device: str) -> torch.nn.Module:
    """Build the configured base architecture and load the saved weights.

    :param model_path: path to the saved ``state_dict`` checkpoint
    :param num_classes: size of the classification head (must match the checkpoint)
    :type num_classes: int
    :param device: target device string, e.g. ``"cuda"`` or ``"cpu"``
    :type device: str
    :return: the model on ``device``, switched to evaluation mode
    :rtype: torch.nn.Module
    """
    # NOTE(review): torch.load unpickles the file; config.MODEL_PATH must only
    # ever point at a trusted checkpoint.
    model_state_dict = torch.load(model_path, map_location=device)

    model = models.get_model(config.BASE_MODEL_NAME, num_classes=num_classes)
    model.load_state_dict(model_state_dict)
    model.to(device)
    # Put layers such as dropout / batch-norm into inference behaviour.
    model.eval()
    return model


async def classify_mushroom_in_image_svc(img: Image.Image) -> Tuple[str, str, float]:
    """Service used to classify a mushroom shown in an image.

    The mushroom is classified to one of many well known mushroom classes/types,
    as well as according to its toxicity profile (i.e. edible or poisonous).
    Additionally, a probability is returned showing confidence of classification.

    :param img: the image of the mushroom to be classified
    :type img: Image.Image
    :return: mushroom_type, toxicity_profile, classification_confidence
        (the confidence is the top softmax probability, rounded to 4 decimals)
    :rtype: Tuple[str, str, float]
    :raises HTTPException: with status 500 if any step of the classification fails
    """

    try:
        # Device agnostic
        device = "cuda" if torch.cuda.is_available() else "cpu"

        logger.debug("Loading classification model.")

        # NOTE(review): the checkpoint and labels are re-read on every call, and
        # the CPU/GPU work below runs synchronously inside an async function
        # (blocking the event loop). Consider caching the model and/or
        # offloading this work to a thread pool.
        model_path = config.MODEL_PATH
        labels_path = Path(model_path).resolve().parent / "labels.txt"
        class_names = _read_class_names(labels_path)
        model = _load_model(model_path, len(class_names), device)

        # The default preprocessing transforms shipped with the base model's weights.
        weights_enum = models.get_model_weights(config.BASE_MODEL_NAME)
        image_transform = weights_enum.DEFAULT.transforms()

        # The default transforms normalize per RGB channel, so RGBA / grayscale /
        # palette images would fail there; convert them to 3-channel RGB first.
        if img.mode != "RGB":
            img = img.convert("RGB")

        with torch.inference_mode():
            logger.debug("Adapting input image by applying necessary transforms!")
            # Add a batch dimension: the model expects [batch_size, color_channels, height, width].
            batch = image_transform(img).unsqueeze(dim=0).to(device)

            logger.debug("Starting classification process...")
            logits = model(batch)

            # Convert logits -> prediction probabilities (softmax for multi-class classification)
            probs = torch.softmax(logits, dim=1)
            # Top probability and its class index for the single image in the batch.
            top_prob, top_idx = probs.max(dim=1)

        class_name = class_names[top_idx.item()]

        # Split class_name to mushroom type and toxicity profile (on the last "_" only).
        class_type, toxicity = class_name.rsplit("_", 1)

        # 4 decimal points precision
        prob = round(top_prob.item(), 4)

        return class_type, toxicity, prob

    except Exception as e:
        # logger.exception records the traceback; the original call lacked the
        # f-prefix and logged the literal text "{e}".
        logger.exception("Classification process error: %s", e)
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Classification process failed due to an internal error. Contact support if this persists.",
        ) from e