blasisd commited on
Commit
9e35b9e
·
1 Parent(s): fa7a439

Initial commit

Browse files
.dockerignore ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ .git
2
+ .gitattributes
3
+ .coverage
4
+ .coverage.*
5
+ .env
6
+ .venv
7
+ .aws
8
+ *.log
9
+ *.md
Dockerfile ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ARG APP_PORT
2
+ FROM python:3.10
3
+
4
+ # The two following lines are requirements for the Dev Mode to be functional
5
+ # Learn more about the Dev Mode at https://huggingface.co/dev-mode-explorers
6
+ RUN useradd -m -u 1000 user
7
+ WORKDIR /app
8
+
9
+ COPY --chown=user ./requirements.txt requirements.txt
10
+ RUN pip install --no-cache-dir --upgrade -r requirements.txt
11
+
12
+ COPY --chown=user ./src /app
13
+ CMD ["uvicorn", "main:app", "--host", "0.0.0.0", "--port", $APP_PORT]
requirements.txt ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ accelerate==1.8.1
2
+ fastapi==0.116.0
3
+ pillow==11.3.0
4
+ python-multipart==0.0.20
5
+ torch==2.7.1
6
+ torchvision==0.22.1
7
+ transformers==4.53.1
8
+ uvicorn==0.35.0
src/backup_services.py ADDED
@@ -0,0 +1,91 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+
3
+ from pathlib import Path
4
+
5
+ import torch
6
+
7
+ from fastapi import HTTPException, status
8
+ from PIL import Image
9
+ from torchvision import models
10
+ from typing import Tuple
11
+
12
+ import src.config as config
13
+
14
+
15
+ logger = logging.getLogger(__name__)
16
+
17
+
18
+ async def classify_mushroom_in_image_svc(img: Image.Image) -> Tuple[str, str, str]:
19
+ """Service used to classify a mushroom shown in an image.
20
+ The mushroom is classified to one of many well known mushroom classes/types,
21
+ as well as according to its toxicity profile (i.e. edible or poisonous).
22
+ Additionally, a probability is returned showing confidence of classification.
23
+
24
+ :param img: the image of the mushroom to be classified
25
+ :type img: Image.Image
26
+ :return: mushroom_type, toxicity_profile, classification_confidence
27
+ :rtype: Tuple[str, str, str]
28
+ """
29
+
30
+ try:
31
+ # Device agnostic
32
+ device = "cuda" if torch.cuda.is_available() else "cpu"
33
+
34
+ logger.debug("Loading classification model.")
35
+
36
+ model_path = config.MODEL_PATH
37
+
38
+ # Load saved model checkpoint
39
+ model_state_dict = torch.load(model_path, map_location=device)
40
+
41
+ # Get class_names from saved model checkpoint
42
+ model_dirname = Path(model_path).resolve().parent
43
+ with open(model_dirname / "labels.txt", "r") as labels_fp:
44
+ class_names = [line.strip() for line in labels_fp]
45
+
46
+ model = models.get_model(config.BASE_MODEL_NAME, num_classes=len(class_names))
47
+
48
+ # Load state_dict of saved model
49
+ model.load_state_dict(model_state_dict)
50
+
51
+ weights_enum = models.get_model_weights(config.BASE_MODEL_NAME)
52
+
53
+ # Get the model's default transforms
54
+ image_transform = weights_enum.DEFAULT.transforms()
55
+
56
+ # Make sure the model is on the target device
57
+ model.to(device)
58
+
59
+ # Turn on model evaluation mode and inference mode
60
+ model.eval()
61
+ with torch.inference_mode():
62
+ logger.debug("Adapting input image by applying necessary transforms!")
63
+ # Transform and add an extra dimension to image (model requires samples in [batch_size, color_channels, height, width])
64
+ transformed_image = image_transform(img).unsqueeze(dim=0)
65
+
66
+ # Make a prediction on image with an extra dimension and send it to the target device
67
+ target_image_pred = model(transformed_image.to(device))
68
+
69
+ logger.debug("Starting classification process...")
70
+ # Convert logits -> prediction probabilities (using torch.softmax() for multi-class classification)
71
+ target_image_pred_probs = torch.softmax(target_image_pred, dim=1)
72
+
73
+ # Convert prediction probabilities -> prediction labels
74
+ target_image_pred_label = torch.argmax(target_image_pred_probs, dim=1)
75
+
76
+ class_name = class_names[target_image_pred_label]
77
+
78
+ # Split class_name to mushroom type and toxicity profile
79
+ class_type, toxicity = class_name.rsplit("_", 1)
80
+
81
+ # 4 decimal points precision
82
+ prob = round(target_image_pred_probs.max().item(), 4)
83
+
84
+ return class_type, toxicity, prob
85
+
86
+ except Exception as e:
87
+ logger.error("Classification process error: {e}")
88
+ raise HTTPException(
89
+ status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
90
+ detail="Classification process failed due to an internal error. Contact support if this persists.",
91
+ )
src/config.py ADDED
@@ -0,0 +1 @@
 
 
1
+ MODEL_ID = "blasisd/musheff"
src/dependencies.py ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastapi import Request, HTTPException
2
+ from transformers import PreTrainedModel
3
+ from transformers.image_processing_utils import BaseImageProcessor
4
+
5
+
6
+ def get_model(request: Request) -> PreTrainedModel:
7
+ if not hasattr(request.app.state, "model"):
8
+ raise HTTPException(status_code=500, detail="Model not loaded")
9
+ return request.app.state.model
10
+
11
+
12
+ def get_preprocessor(request: Request) -> BaseImageProcessor:
13
+ if not hasattr(request.app.state, "preprocessor"):
14
+ raise HTTPException(status_code=500, detail="Preprocessor not loaded")
15
+ return request.app.state.preprocessor
src/main.py ADDED
@@ -0,0 +1,47 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from contextlib import asynccontextmanager
2
+ from fastapi import FastAPI
3
+ from transformers import (
4
+ AutoImageProcessor,
5
+ AutoModel,
6
+ )
7
+
8
+ import src.config as config
9
+
10
+ from src.router import router
11
+
12
+
13
+ @asynccontextmanager
14
+ async def lifespan(app: FastAPI):
15
+ # Load models during startup
16
+
17
+ app.state.model = AutoModel.from_pretrained(
18
+ config.MODEL_ID,
19
+ trust_remote_code=True,
20
+ low_cpu_mem_usage=True, # Activates memory-efficient loading
21
+ device_map="auto", # Distributes layers across devices
22
+ )
23
+
24
+ app.state.preprocessor = AutoImageProcessor.from_pretrained(
25
+ config.MODEL_ID,
26
+ trust_remote_code=True,
27
+ use_fast=True,
28
+ )
29
+
30
+ yield
31
+
32
+ # Cleanup during shutdown (e.g., GPU memory)
33
+ del app.state.model
34
+ del app.state.preprocessor
35
+
36
+
37
+ app = FastAPI(
38
+ description="Mushrooms Classification API", version="0.1.0", lifespan=lifespan
39
+ )
40
+
41
+
42
+ @app.get("/")
43
+ async def root():
44
+ return {"message": "Welcome to Mushrooms Classification API 🍄"}
45
+
46
+
47
+ app.include_router(router)
src/router.py ADDED
@@ -0,0 +1,59 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import io
2
+ import logging
3
+
4
+ from fastapi import APIRouter, HTTPException, status, UploadFile, Depends
5
+ from PIL import Image
6
+
7
+ from src.dependencies import get_model, get_preprocessor
8
+ from src.schema import MushroomClassification
9
+ from src.services import classify_mushroom_in_image_svc
10
+
11
+
12
+ logger = logging.getLogger(__name__)
13
+
14
+ router = APIRouter()
15
+
16
+
17
+ @router.post(
18
+ "/classify", response_model=MushroomClassification, status_code=status.HTTP_200_OK
19
+ )
20
+ async def classify_mushroom_in_image(
21
+ image_file: UploadFile,
22
+ model=Depends(get_model),
23
+ preprocessor=Depends(get_preprocessor),
24
+ ):
25
+ """Open uploaded image file and call mushroom classification
26
+ service.
27
+
28
+ :param image_file: the uploaded image file
29
+ :type image_file: UploadFile
30
+ :param model: the pretrained model, defaults to Depends(get_model)
31
+ :type model: PreTrainedModel, optional
32
+ :param preprocessor: the preprocessor for image input transforms, defaults to Depends(get_preprocessor)
33
+ :type preprocessor: BaseImageProcessor, optional
34
+ :raises HTTPException: Internal Server Error in case of model/preprocessor loading failure or some uknown error,
35
+ or Bad Request Error in case of corrupted or invalid uploaded file
36
+ :return: mushroom_type, toxicity_profile, classification_confidence
37
+ :rtype: MushroomClassification
38
+ """
39
+ logger.info(f"Classify image: {image_file.filename}")
40
+
41
+ try:
42
+ request_object_content = await image_file.read()
43
+ img = Image.open(io.BytesIO(request_object_content))
44
+ if img.mode != "RGB":
45
+ img = img.convert("RGB")
46
+ except Exception as e:
47
+ logger.error(f"Error reading file: {str(e)}")
48
+ raise HTTPException(
49
+ status_code=status.HTTP_400_BAD_REQUEST,
50
+ detail="Failed to read uploaded file. The file may be corrupted or invalid.",
51
+ )
52
+
53
+ # Get class_name predicted and prediction probability
54
+ class_name, toxicity, confidence = await classify_mushroom_in_image_svc(
55
+ img, model, preprocessor
56
+ )
57
+ return MushroomClassification(
58
+ mushroom_type=class_name, toxicity_profile=toxicity, confidence=confidence
59
+ )
src/schema.py ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ from pydantic import BaseModel
2
+
3
+
4
+ class MushroomClassification(BaseModel):
5
+ mushroom_type: str
6
+ toxicity_profile: str
7
+ confidence: float
src/services.py ADDED
@@ -0,0 +1,72 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+
3
+ from typing import Tuple
4
+
5
+ import torch
6
+
7
+ from fastapi import HTTPException, status
8
+ from PIL import Image
9
+ from transformers import PreTrainedModel
10
+ from transformers.image_processing_utils import BaseImageProcessor
11
+
12
+
13
+ logger = logging.getLogger(__name__)
14
+
15
+
16
+ async def classify_mushroom_in_image_svc(
17
+ img: Image.Image, model: PreTrainedModel, preprocessor: BaseImageProcessor
18
+ ) -> Tuple[str, str, str]:
19
+ """Service used to classify a mushroom shown in an image.
20
+ The mushroom is classified to one of many well known mushroom classes/types,
21
+ as well as according to its toxicity profile (i.e. edible or poisonous).
22
+ Additionally, a probability is returned showing confidence of classification.
23
+
24
+ :param img: the input image of the mushroom to be classified
25
+ :type img: Image.Image
26
+ :param model: the pretrained model
27
+ :type model: PretrainedModel
28
+ :param preprocessor: the auto preprocessor for image transforms (rescales, crops, normalizations etc.)
29
+ :type preprocessor: BaseImageProcessor
30
+ :raises HTTPException: Internal Server Error
31
+ :return: mushroom_type, toxicity_profile, classification_confidence
32
+ :rtype: Tuple[str, str, float]
33
+ """
34
+
35
+ try:
36
+
37
+ logger.debug("Loading classification model.")
38
+
39
+ inputs = preprocessor(img, return_tensors="pt").to(model.device)
40
+
41
+ # Turn on model evaluation mode and inference mode
42
+ model.eval()
43
+ with torch.inference_mode():
44
+ logger.debug("Starting classification process...")
45
+
46
+ # Make a prediction on image with an extra dimension and send it to the target device
47
+ target_image_pred = model(inputs["pixel_values"])
48
+
49
+ # Convert logits -> prediction probabilities (using torch.softmax() for multi-class classification)
50
+ target_image_pred_probs = torch.softmax(target_image_pred, dim=1)
51
+
52
+ # model predicts one of the 12 potential mushroom classes
53
+ predicted_label = target_image_pred.argmax(dim=1).item()
54
+
55
+ # Get the label/class name of the prediction made using id2label
56
+ class_name = model.config.id2label[predicted_label]
57
+
58
+ # Split class_name to mushroom type and toxicity profile
59
+ class_type, toxicity = class_name.rsplit("_", 1)
60
+
61
+ # 4 decimal points precision
62
+ prob = round(target_image_pred_probs.max().item(), 4)
63
+
64
+ logger.debug("Finished classification process...")
65
+ return class_type, toxicity, prob
66
+
67
+ except Exception as e:
68
+ logger.error(f"Classification process error: {e}")
69
+ raise HTTPException(
70
+ status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
71
+ detail="Classification process failed due to an internal error. Contact support if this persists.",
72
+ )