File size: 2,754 Bytes
9e35b9e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
import logging

from typing import Tuple

import torch

from fastapi import HTTPException, status
from PIL import Image
from transformers import PreTrainedModel
from transformers.image_processing_utils import BaseImageProcessor


logger = logging.getLogger(__name__)


async def classify_mushroom_in_image_svc(
    img: Image.Image, model: PreTrainedModel, preprocessor: BaseImageProcessor
) -> Tuple[str, str, float]:
    """Service used to classify a mushroom shown in an image.

    The mushroom is classified into one of the classes listed in
    ``model.config.id2label``. Each label is split on its LAST underscore into
    a mushroom type and a toxicity profile (e.g. edible or poisonous), so labels
    are expected to look like ``<mushroom_type>_<toxicity_profile>``.
    Additionally, the softmax probability of the predicted class is returned as
    the classification confidence.

    :param img: the input image of the mushroom to be classified
    :type img: Image.Image
    :param model: the pretrained image-classification model
    :type model: PreTrainedModel
    :param preprocessor: the auto preprocessor for image transforms (rescales, crops, normalizations etc.)
    :type preprocessor: BaseImageProcessor
    :raises HTTPException: 500 Internal Server Error if preprocessing,
        inference or post-processing fails for any reason
    :return: mushroom_type, toxicity_profile, classification_confidence
        (confidence rounded to 4 decimal places)
    :rtype: Tuple[str, str, float]
    """
    # NOTE(review): inference runs synchronously inside this coroutine, so it
    # blocks the event loop while the model executes; consider offloading it
    # to a worker thread if request concurrency matters.
    try:
        logger.debug("Preprocessing input image.")

        # Preprocessor output must live on the same device as the model weights
        inputs = preprocessor(img, return_tensors="pt").to(model.device)

        # eval() disables training-only behaviour (dropout etc.); inference_mode
        # disables autograd tracking for the forward pass and post-processing.
        model.eval()
        with torch.inference_mode():
            logger.debug("Starting classification process...")

            outputs = model(inputs["pixel_values"])

            # Hugging Face models return a ModelOutput with `.logits`
            # (or a tuple whose first item is the logits when
            # return_dict=False); a plain tensor is used as-is.
            logits = getattr(outputs, "logits", outputs)
            if isinstance(logits, tuple):
                logits = logits[0]

            # Convert logits -> prediction probabilities (multi-class classification)
            probs = torch.softmax(logits, dim=1)

            # A single image is classified, so argmax yields one class index
            predicted_label = logits.argmax(dim=1).item()

            # Probability of the predicted class; softmax is monotonic, so this
            # is the highest probability in the distribution.
            confidence = probs[0, predicted_label].item()

        # Get the label/class name of the prediction made using id2label
        class_name = model.config.id2label[predicted_label]

        # Split class_name into mushroom type and toxicity profile
        class_type, toxicity = class_name.rsplit("_", 1)

        # 4 decimal points precision
        prob = round(confidence, 4)

        logger.debug("Finished classification process...")
        return class_type, toxicity, prob

    except Exception as e:
        # logger.exception keeps the traceback for diagnosis; the client only
        # receives a generic message.
        logger.exception("Classification process error: %s", e)
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Classification process failed due to an internal error. Contact support if this persists.",
        ) from e