# musheff-api / src/backup_services.py
# Last commit: 9e35b9e ("Initial commit") by blasisd
import functools
import logging
from pathlib import Path
from typing import Tuple

import torch
from fastapi import HTTPException, status
from PIL import Image
from torchvision import models

import src.config as config
logger = logging.getLogger(__name__)


def _read_class_names(labels_path: Path) -> Tuple[str, ...]:
    """Read class names from ``labels_path``, one per line.

    Surrounding whitespace is stripped and blank lines are skipped, so a stray
    empty line cannot add a bogus class and break ``load_state_dict``.

    :param labels_path: path to the labels text file
    :return: class names in file order (index == model output index)
    """
    with open(labels_path, "r", encoding="utf-8") as labels_fp:
        return tuple(name for name in (line.strip() for line in labels_fp) if name)


@functools.lru_cache(maxsize=1)
def _load_classifier(model_path, base_model_name: str, device: str):
    """Build the classifier, its class names and its preprocessing transform.

    The result is cached so the checkpoint is read from disk once per
    (model_path, base_model_name, device) combination rather than on every
    request. Edits to the checkpoint or ``labels.txt`` on disk are therefore
    not picked up until the process restarts.

    :param model_path: path to the saved state_dict checkpoint
    :param base_model_name: torchvision model name passed to ``models.get_model``
    :param device: torch device string the model is moved to
    :return: (model in eval mode on ``device``, class names, image transform)
    """
    logger.debug("Loading classification model.")
    # The checkpoint is a plain state_dict (it is fed to load_state_dict below),
    # so restrict unpickling to tensors/primitive containers.
    model_state_dict = torch.load(model_path, map_location=device, weights_only=True)

    # Class names are stored next to the checkpoint in labels.txt
    class_names = _read_class_names(Path(model_path).resolve().parent / "labels.txt")

    model = models.get_model(base_model_name, num_classes=len(class_names))
    model.load_state_dict(model_state_dict)
    model.to(device)
    model.eval()

    # Preprocessing matching the base model's default pretrained weights
    image_transform = models.get_model_weights(base_model_name).DEFAULT.transforms()
    return model, class_names, image_transform


async def classify_mushroom_in_image_svc(img: Image.Image) -> Tuple[str, str, float]:
    """Service used to classify a mushroom shown in an image.

    The mushroom is classified into one of the known mushroom classes/types,
    as well as according to its toxicity profile (the part of the label after
    the last underscore, e.g. edible or poisonous). A probability is also
    returned showing the confidence of the classification.

    :param img: the image of the mushroom to be classified; non-RGB images are
        converted to RGB before preprocessing
    :type img: Image.Image
    :return: mushroom_type, toxicity_profile, classification_confidence
        (softmax probability of the top class, rounded to 4 decimals)
    :rtype: Tuple[str, str, float]
    :raises HTTPException: 500 if any step of loading or classification fails
    """
    # NOTE(review): there is no await in this coroutine, so model loading and
    # inference run directly on the event loop and block it for their duration;
    # consider offloading to a worker thread if latency matters.
    try:
        # Device agnostic
        device = "cuda" if torch.cuda.is_available() else "cpu"
        model, class_names, image_transform = _load_classifier(
            config.MODEL_PATH, config.BASE_MODEL_NAME, device
        )

        # The model expects 3 colour channels; RGBA/L/P images would otherwise
        # produce a tensor with the wrong channel count.
        if img.mode != "RGB":
            img = img.convert("RGB")

        with torch.inference_mode():
            logger.debug("Adapting input image by applying necessary transforms!")
            # Add a batch dimension: [batch_size, color_channels, height, width]
            transformed_image = image_transform(img).unsqueeze(dim=0)

            logger.debug("Starting classification process...")
            logits = model(transformed_image.to(device))
            # Logits -> probabilities (multi-class classification)
            probs = torch.softmax(logits, dim=1)
            # Batch of one image: top-1 index and its probability
            pred_index = int(torch.argmax(probs, dim=1).item())
            top_prob = probs.max().item()

        class_name = class_names[pred_index]
        # Labels are "<mushroom_type>_<toxicity>"; split on the last underscore
        class_type, toxicity = class_name.rsplit("_", 1)
        # 4 decimal points precision
        prob = round(top_prob, 4)
        return class_type, toxicity, prob
    except Exception as e:
        logger.exception("Classification process error: %s", e)
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Classification process failed due to an internal error. Contact support if this persists.",
        ) from e