Spaces:
Sleeping
Sleeping
import logging | |
from typing import Tuple | |
import torch | |
from fastapi import HTTPException, status | |
from PIL import Image | |
from transformers import PreTrainedModel | |
from transformers.image_processing_utils import BaseImageProcessor | |
logger = logging.getLogger(__name__) | |
async def classify_mushroom_in_image_svc(
    img: Image.Image, model: PreTrainedModel, preprocessor: BaseImageProcessor
) -> Tuple[str, str, float]:
    """Service used to classify a mushroom shown in an image.

    The mushroom is classified to one of the classes known to the model
    (``model.config.id2label``). Each class name is expected to have the form
    ``<mushroom_type>_<toxicity>``; it is split on the last underscore into the
    mushroom type and its toxicity profile (i.e. edible or poisonous).
    Additionally, the probability of the predicted class is returned as a
    measure of classification confidence.

    NOTE(review): inference runs synchronously inside this coroutine, so it
    presumably blocks the event loop for the duration of the forward pass --
    confirm whether that is acceptable for the deployment.

    :param img: the input image of the mushroom to be classified
    :type img: Image.Image
    :param model: the pretrained model
    :type model: PreTrainedModel
    :param preprocessor: the auto preprocessor for image transforms (rescales, crops, normalizations etc.)
    :type preprocessor: BaseImageProcessor
    :raises HTTPException: Internal Server Error (500) if any step of the classification fails
    :return: mushroom_type, toxicity_profile, classification_confidence
    :rtype: Tuple[str, str, float]
    """
    try:
        logger.debug("Loading classification model.")
        # Apply the image transforms and move the resulting tensors to the model's device
        inputs = preprocessor(img, return_tensors="pt").to(model.device)

        # Turn on model evaluation mode and inference mode
        model.eval()
        with torch.inference_mode():
            logger.debug("Starting classification process...")
            output = model(inputs["pixel_values"])

            # Hugging Face models usually return a ModelOutput exposing `.logits`
            # (or a tuple whose first element is the logits when return_dict=False);
            # a custom model may return the logits tensor directly. Accept all three.
            logits = getattr(output, "logits", output)
            if isinstance(logits, (tuple, list)):
                logits = logits[0]

            # Convert logits -> prediction probabilities (softmax over the class dimension)
            pred_probs = torch.softmax(logits, dim=1)

            # Index of the most likely class; `.item()` requires a single-image batch
            predicted_label = logits.argmax(dim=1).item()

            # Get the label/class name of the prediction made using id2label
            class_name = model.config.id2label[predicted_label]

            # Split class_name into mushroom type and toxicity profile (last underscore)
            class_type, toxicity = class_name.rsplit("_", 1)

            # 4 decimal points precision
            prob = round(pred_probs.max().item(), 4)

        logger.debug("Finished classification process...")
        return class_type, toxicity, prob
    except Exception as e:
        # logger.exception records the traceback, which logger.error(f"...") dropped
        logger.exception("Classification process error: %s", e)
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Classification process failed due to an internal error. Contact support if this persists.",
        ) from e