File size: 2,143 Bytes
9e35b9e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
import io
import logging

from fastapi import APIRouter, HTTPException, status, UploadFile, Depends
from PIL import Image

from src.dependencies import get_model, get_preprocessor
from src.schema import MushroomClassification
from src.services import classify_mushroom_in_image_svc


# Module-level logger named after this module.
logger = logging.getLogger(__name__)

# Router collecting this module's endpoints; presumably included into the
# FastAPI application elsewhere (not visible here).
router = APIRouter()


@router.post(
    "/classify", response_model=MushroomClassification, status_code=status.HTTP_200_OK
)
async def classify_mushroom_in_image(
    image_file: UploadFile,
    model=Depends(get_model),
    preprocessor=Depends(get_preprocessor),
):
    """Decode the uploaded image and run the mushroom classification service.

    The upload is read fully into memory, decoded with Pillow (forcing the
    pixel data to be loaded, not just the header) and converted to RGB if
    needed, then passed to ``classify_mushroom_in_image_svc`` together with
    the injected model and preprocessor.

    :param image_file: the uploaded image file
    :type image_file: UploadFile
    :param model: the pretrained model, injected via ``Depends(get_model)``
    :type model: PreTrainedModel, optional
    :param preprocessor: the preprocessor for image input transforms,
        injected via ``Depends(get_preprocessor)``
    :type preprocessor: BaseImageProcessor, optional
    :raises HTTPException: 400 Bad Request if the upload cannot be read or
        decoded as an image (e.g. empty, corrupted, truncated or unsupported
        file). Failures in the dependencies or in the classification service
        are not caught here and propagate unchanged (presumably surfacing as
        500 Internal Server Error).
    :return: the predicted mushroom type, its toxicity profile and the
        classification confidence
    :rtype: MushroomClassification
    """
    logger.info("Classify image: %s", image_file.filename)

    try:
        content = await image_file.read()
        img = Image.open(io.BytesIO(content))
        # Image.open() is lazy and only parses the header. Force a full decode
        # here so truncated/corrupted pixel data is reported as a 400 instead
        # of blowing up later inside the classification service (previously
        # this only happened implicitly for non-RGB images via convert()).
        img.load()
        if img.mode != "RGB":
            img = img.convert("RGB")
    except Exception as e:
        # Any failure while reading/decoding the upload is treated as a
        # client error rather than a server error.
        logger.error("Error reading file: %s", e)
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Failed to read uploaded file. The file may be corrupted or invalid.",
        ) from e

    # Get predicted class name, toxicity profile and prediction confidence
    class_name, toxicity, confidence = await classify_mushroom_in_image_svc(
        img, model, preprocessor
    )
    return MushroomClassification(
        mushroom_type=class_name, toxicity_profile=toxicity, confidence=confidence
    )