musheff-api / src /services.py
blasisd's picture
Initial commit
9e35b9e
import logging
from typing import Tuple
import torch
from fastapi import HTTPException, status
from PIL import Image
from transformers import PreTrainedModel
from transformers.image_processing_utils import BaseImageProcessor
logger = logging.getLogger(__name__)
async def classify_mushroom_in_image_svc(
img: Image.Image, model: PreTrainedModel, preprocessor: BaseImageProcessor
) -> Tuple[str, str, str]:
"""Service used to classify a mushroom shown in an image.
The mushroom is classified to one of many well known mushroom classes/types,
as well as according to its toxicity profile (i.e. edible or poisonous).
Additionally, a probability is returned showing confidence of classification.
:param img: the input image of the mushroom to be classified
:type img: Image.Image
:param model: the pretrained model
:type model: PretrainedModel
:param preprocessor: the auto preprocessor for image transforms (rescales, crops, normalizations etc.)
:type preprocessor: BaseImageProcessor
:raises HTTPException: Internal Server Error
:return: mushroom_type, toxicity_profile, classification_confidence
:rtype: Tuple[str, str, float]
"""
try:
logger.debug("Loading classification model.")
inputs = preprocessor(img, return_tensors="pt").to(model.device)
# Turn on model evaluation mode and inference mode
model.eval()
with torch.inference_mode():
logger.debug("Starting classification process...")
# Make a prediction on image with an extra dimension and send it to the target device
target_image_pred = model(inputs["pixel_values"])
# Convert logits -> prediction probabilities (using torch.softmax() for multi-class classification)
target_image_pred_probs = torch.softmax(target_image_pred, dim=1)
# model predicts one of the 12 potential mushroom classes
predicted_label = target_image_pred.argmax(dim=1).item()
# Get the label/class name of the prediction made using id2label
class_name = model.config.id2label[predicted_label]
# Split class_name to mushroom type and toxicity profile
class_type, toxicity = class_name.rsplit("_", 1)
# 4 decimal points precision
prob = round(target_image_pred_probs.max().item(), 4)
logger.debug("Finished classification process...")
return class_type, toxicity, prob
except Exception as e:
logger.error(f"Classification process error: {e}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Classification process failed due to an internal error. Contact support if this persists.",
)