Spaces:
Sleeping
Sleeping
import logging | |
from pathlib import Path | |
import torch | |
from fastapi import HTTPException, status | |
from PIL import Image | |
from torchvision import models | |
from typing import Tuple | |
import src.config as config | |
logger = logging.getLogger(__name__) | |
async def classify_mushroom_in_image_svc(img: Image.Image) -> Tuple[str, str, float]:
    """Service used to classify a mushroom shown in an image.

    The mushroom is classified to one of the classes listed in the
    ``labels.txt`` file stored next to the model checkpoint. Each class name is
    expected to look like ``<mushroom_type>_<toxicity>``; it is split on the
    last underscore into the mushroom type and its toxicity profile.
    Additionally, the probability of the predicted class is returned, rounded
    to 4 decimal places.

    :param img: the image of the mushroom to be classified
    :return: mushroom_type, toxicity_profile, classification_confidence
    :raises HTTPException: with status 500 when any step of loading the model
        or running the classification fails
    """
    try:
        # Device agnostic
        device = "cuda" if torch.cuda.is_available() else "cpu"

        logger.debug("Loading classification model.")
        model_path = config.MODEL_PATH

        # NOTE(review): the checkpoint and labels are re-read on every call, and
        # all of the work below is blocking even though this is a coroutine.
        # Consider caching the loaded model or moving inference off the event loop.
        # NOTE(security): torch.load may unpickle arbitrary objects; only point
        # config.MODEL_PATH at trusted checkpoints (or pass weights_only=True).
        model_state_dict = torch.load(model_path, map_location=device)

        # Class names live in labels.txt next to the checkpoint, one per line
        labels_path = Path(model_path).resolve().parent / "labels.txt"
        with open(labels_path, "r", encoding="utf-8") as labels_fp:
            class_names = [line.strip() for line in labels_fp]

        # Rebuild the architecture with a matching number of output classes,
        # then restore the trained weights
        model = models.get_model(config.BASE_MODEL_NAME, num_classes=len(class_names))
        model.load_state_dict(model_state_dict)

        # Get the base model's default preprocessing transforms
        weights_enum = models.get_model_weights(config.BASE_MODEL_NAME)
        image_transform = weights_enum.DEFAULT.transforms()

        # Make sure the model is on the target device and in evaluation mode
        model.to(device)
        model.eval()

        # Images with an alpha channel, palette or single channel would not
        # match the 3-channel input the transforms/model are applied to
        if img.mode != "RGB":
            img = img.convert("RGB")

        with torch.inference_mode():
            logger.debug("Adapting input image by applying necessary transforms!")
            # Add a batch dimension: [color_channels, H, W] -> [1, color_channels, H, W]
            transformed_image = image_transform(img).unsqueeze(dim=0)

            logger.debug("Starting classification process...")
            target_image_pred = model(transformed_image.to(device))

            # Convert logits -> prediction probabilities (multi-class => softmax)
            target_image_pred_probs = torch.softmax(target_image_pred, dim=1)
            # Convert prediction probabilities -> index of the predicted class
            # (.item() yields a plain int suitable for list indexing)
            predicted_index = torch.argmax(target_image_pred_probs, dim=1).item()
            # 4 decimal points precision
            prob = round(target_image_pred_probs.max().item(), 4)

        class_name = class_names[predicted_index]
        # Split class_name to mushroom type and toxicity profile
        class_type, toxicity = class_name.rsplit("_", 1)

        return class_type, toxicity, prob
    except Exception as e:
        # Log the real cause (with traceback) server-side; the client only
        # receives a generic message
        logger.exception("Classification process error: %s", e)
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Classification process failed due to an internal error. Contact support if this persists.",
        ) from e