Spaces:
Sleeping
Sleeping
File size: 3,398 Bytes
9e35b9e |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 |
import logging
from pathlib import Path
import torch
from fastapi import HTTPException, status
from PIL import Image
from torchvision import models
from typing import Tuple
import src.config as config
# Module-level logger named after this module; used by the service below for
# debug and error messages.
logger = logging.getLogger(__name__)
async def classify_mushroom_in_image_svc(img: Image.Image) -> Tuple[str, str, float]:
    """Service used to classify a mushroom shown in an image.

    The mushroom is classified to one of many well known mushroom classes/types,
    as well as according to its toxicity profile (i.e. edible or poisonous).
    Additionally, a probability is returned showing confidence of classification.

    The class labels are read from a ``labels.txt`` file located next to the
    model checkpoint (``config.MODEL_PATH``), one label per line. Each label is
    split on its last underscore into ``<mushroom_type>_<toxicity_profile>``.

    :param img: the image of the mushroom to be classified
    :type img: Image.Image
    :return: mushroom_type, toxicity_profile, classification_confidence
        (confidence is the top softmax probability rounded to 4 decimals)
    :rtype: Tuple[str, str, float]
    :raises HTTPException: with status 500 if any step of loading the model or
        classifying the image fails; the original error is logged and chained.
    """
    try:
        # Device agnostic
        device = "cuda" if torch.cuda.is_available() else "cpu"

        logger.debug("Loading classification model.")
        model_path = config.MODEL_PATH

        # Load saved model checkpoint.
        # NOTE(review): the checkpoint and labels are re-read on every call;
        # consider caching the loaded model if request latency matters.
        # NOTE(review): torch.load deserializes the file, so MODEL_PATH must
        # point to a trusted checkpoint.
        model_state_dict = torch.load(model_path, map_location=device)

        # Get class_names from the labels file stored beside the checkpoint
        model_dirname = Path(model_path).resolve().parent
        with open(model_dirname / "labels.txt", "r") as labels_fp:
            class_names = [line.strip() for line in labels_fp]

        # Build the base architecture sized to the number of labels, then
        # load the saved weights into it
        model = models.get_model(config.BASE_MODEL_NAME, num_classes=len(class_names))
        model.load_state_dict(model_state_dict)

        # Get the model's default transforms
        weights_enum = models.get_model_weights(config.BASE_MODEL_NAME)
        image_transform = weights_enum.DEFAULT.transforms()

        # Make sure the model is on the target device
        model.to(device)

        # Turn on model evaluation mode and inference mode
        model.eval()
        with torch.inference_mode():
            logger.debug("Adapting input image by applying necessary transforms!")
            # Transform and add an extra dimension to image (model requires
            # samples in [batch_size, color_channels, height, width])
            transformed_image = image_transform(img).unsqueeze(dim=0)

            # Make a prediction on image with an extra dimension and send it
            # to the target device
            target_image_pred = model(transformed_image.to(device))

            logger.debug("Starting classification process...")
            # Convert logits -> prediction probabilities (using torch.softmax()
            # for multi-class classification)
            target_image_pred_probs = torch.softmax(target_image_pred, dim=1)

            # Convert prediction probabilities -> prediction label index.
            # .item() yields a plain int so the list lookup below does not
            # rely on implicit tensor-to-index conversion.
            target_image_pred_label = torch.argmax(target_image_pred_probs, dim=1).item()

        class_name = class_names[target_image_pred_label]

        # Split class_name to mushroom type and toxicity profile
        class_type, toxicity = class_name.rsplit("_", 1)

        # 4 decimal points precision
        prob = round(target_image_pred_probs.max().item(), 4)

        return class_type, toxicity, prob
    except Exception as e:
        # Log the real error (with traceback) server-side; the client only
        # receives a generic message.
        logger.exception("Classification process error: %s", e)
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Classification process failed due to an internal error. Contact support if this persists.",
        ) from e
|