# musheff-api / src/services.py
# Source: commit 9e35b9e ("Initial commit") by blasisd (2.75 kB)
import logging
from typing import Tuple
import torch
from fastapi import HTTPException, status
from PIL import Image
from transformers import PreTrainedModel
from transformers.image_processing_utils import BaseImageProcessor
# Module-level logger named after this module (standard ``logging.getLogger(__name__)`` pattern).
logger = logging.getLogger(__name__)
async def classify_mushroom_in_image_svc(
    img: Image.Image, model: PreTrainedModel, preprocessor: BaseImageProcessor
) -> Tuple[str, str, float]:
    """Service used to classify a mushroom shown in an image.

    The mushroom is classified to one of the model's mushroom classes/types,
    as well as according to its toxicity profile (i.e. edible or poisonous).
    Additionally, a probability is returned showing confidence of classification.

    Each label in ``model.config.id2label`` is expected to have the form
    ``"<mushroom_type>_<toxicity>"``; it is split on its last underscore.

    :param img: the input image of the mushroom to be classified
    :type img: Image.Image
    :param model: the pretrained image-classification model
    :type model: PreTrainedModel
    :param preprocessor: the auto preprocessor for image transforms (rescales, crops, normalizations etc.)
    :type preprocessor: BaseImageProcessor
    :raises HTTPException: 500 Internal Server Error if any classification step fails
    :return: mushroom_type, toxicity_profile, classification_confidence
        (probability of the predicted class, rounded to 4 decimals)
    :rtype: Tuple[str, str, float]
    """
    # NOTE(review): this coroutine never awaits, so preprocessing and the forward
    # pass run synchronously on the event loop; consider offloading to a thread
    # pool if request latency/concurrency matters.
    try:
        logger.debug("Loading classification model.")
        # Preprocess the single image into model-ready tensors on the model's device.
        inputs = preprocessor(img, return_tensors="pt").to(model.device)

        # eval() switches off training-only layers (dropout etc.);
        # inference_mode() disables autograd bookkeeping.
        model.eval()
        with torch.inference_mode():
            logger.debug("Starting classification process...")
            outputs = model(inputs["pixel_values"])
            # Hugging Face classification models return a ModelOutput exposing
            # ``logits``; fall back to the raw output for models returning a tensor.
            logits = getattr(outputs, "logits", outputs)
            # Convert logits -> prediction probabilities over the class dimension.
            probs = torch.softmax(logits, dim=1)

        # Most likely class id; .item() requires exactly one image in the batch.
        predicted_label = logits.argmax(dim=1).item()
        # Map the class id to its label name via the model config.
        class_name = model.config.id2label[predicted_label]
        # Split the label into mushroom type and toxicity profile.
        class_type, toxicity = class_name.rsplit("_", 1)
        # Confidence of the predicted class, 4 decimal points precision.
        prob = round(probs[0, predicted_label].item(), 4)

        logger.debug("Finished classification process...")
        return class_type, toxicity, prob
    except Exception as e:
        # Boundary handler: log the full traceback, then hide internals from the client.
        logger.exception("Classification process error: %s", e)
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Classification process failed due to an internal error. Contact support if this persists.",
        ) from e