Spaces:
Sleeping
Sleeping
Initial commit
Browse files- .dockerignore +9 -0
- Dockerfile +13 -0
- requirements.txt +8 -0
- src/backup_services.py +91 -0
- src/config.py +1 -0
- src/dependencies.py +15 -0
- src/main.py +47 -0
- src/router.py +59 -0
- src/schema.py +7 -0
- src/services.py +72 -0
.dockerignore
ADDED
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
.git
|
2 |
+
.gitattributes
|
3 |
+
.coverage
|
4 |
+
.coverage.*
|
5 |
+
.env
|
6 |
+
.venv
|
7 |
+
.aws
|
8 |
+
*.log
|
9 |
+
*.md
|
Dockerfile
ADDED
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
ARG APP_PORT
|
2 |
+
FROM python:3.10
|
3 |
+
|
4 |
+
# The two following lines are requirements for the Dev Mode to be functional
|
5 |
+
# Learn more about the Dev Mode at https://huggingface.co/dev-mode-explorers
|
6 |
+
RUN useradd -m -u 1000 user
|
7 |
+
WORKDIR /app
|
8 |
+
|
9 |
+
COPY --chown=user ./requirements.txt requirements.txt
|
10 |
+
RUN pip install --no-cache-dir --upgrade -r requirements.txt
|
11 |
+
|
12 |
+
COPY --chown=user ./src /app
|
13 |
+
CMD ["uvicorn", "main:app", "--host", "0.0.0.0", "--port", $APP_PORT]
|
requirements.txt
ADDED
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
accelerate==1.8.1
|
2 |
+
fastapi==0.116.0
|
3 |
+
pillow==11.3.0
|
4 |
+
python-multipart==0.0.20
|
5 |
+
torch==2.7.1
|
6 |
+
torchvision==0.22.1
|
7 |
+
transformers==4.53.1
|
8 |
+
uvicorn==0.35.0
|
src/backup_services.py
ADDED
@@ -0,0 +1,91 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import logging
|
2 |
+
|
3 |
+
from pathlib import Path
|
4 |
+
|
5 |
+
import torch
|
6 |
+
|
7 |
+
from fastapi import HTTPException, status
|
8 |
+
from PIL import Image
|
9 |
+
from torchvision import models
|
10 |
+
from typing import Tuple
|
11 |
+
|
12 |
+
import src.config as config
|
13 |
+
|
14 |
+
|
15 |
+
logger = logging.getLogger(__name__)
|
16 |
+
|
17 |
+
|
18 |
+
async def classify_mushroom_in_image_svc(img: Image.Image) -> Tuple[str, str, str]:
|
19 |
+
"""Service used to classify a mushroom shown in an image.
|
20 |
+
The mushroom is classified to one of many well known mushroom classes/types,
|
21 |
+
as well as according to its toxicity profile (i.e. edible or poisonous).
|
22 |
+
Additionally, a probability is returned showing confidence of classification.
|
23 |
+
|
24 |
+
:param img: the image of the mushroom to be classified
|
25 |
+
:type img: Image.Image
|
26 |
+
:return: mushroom_type, toxicity_profile, classification_confidence
|
27 |
+
:rtype: Tuple[str, str, str]
|
28 |
+
"""
|
29 |
+
|
30 |
+
try:
|
31 |
+
# Device agnostic
|
32 |
+
device = "cuda" if torch.cuda.is_available() else "cpu"
|
33 |
+
|
34 |
+
logger.debug("Loading classification model.")
|
35 |
+
|
36 |
+
model_path = config.MODEL_PATH
|
37 |
+
|
38 |
+
# Load saved model checkpoint
|
39 |
+
model_state_dict = torch.load(model_path, map_location=device)
|
40 |
+
|
41 |
+
# Get class_names from saved model checkpoint
|
42 |
+
model_dirname = Path(model_path).resolve().parent
|
43 |
+
with open(model_dirname / "labels.txt", "r") as labels_fp:
|
44 |
+
class_names = [line.strip() for line in labels_fp]
|
45 |
+
|
46 |
+
model = models.get_model(config.BASE_MODEL_NAME, num_classes=len(class_names))
|
47 |
+
|
48 |
+
# Load state_dict of saved model
|
49 |
+
model.load_state_dict(model_state_dict)
|
50 |
+
|
51 |
+
weights_enum = models.get_model_weights(config.BASE_MODEL_NAME)
|
52 |
+
|
53 |
+
# Get the model's default transforms
|
54 |
+
image_transform = weights_enum.DEFAULT.transforms()
|
55 |
+
|
56 |
+
# Make sure the model is on the target device
|
57 |
+
model.to(device)
|
58 |
+
|
59 |
+
# Turn on model evaluation mode and inference mode
|
60 |
+
model.eval()
|
61 |
+
with torch.inference_mode():
|
62 |
+
logger.debug("Adapting input image by applying necessary transforms!")
|
63 |
+
# Transform and add an extra dimension to image (model requires samples in [batch_size, color_channels, height, width])
|
64 |
+
transformed_image = image_transform(img).unsqueeze(dim=0)
|
65 |
+
|
66 |
+
# Make a prediction on image with an extra dimension and send it to the target device
|
67 |
+
target_image_pred = model(transformed_image.to(device))
|
68 |
+
|
69 |
+
logger.debug("Starting classification process...")
|
70 |
+
# Convert logits -> prediction probabilities (using torch.softmax() for multi-class classification)
|
71 |
+
target_image_pred_probs = torch.softmax(target_image_pred, dim=1)
|
72 |
+
|
73 |
+
# Convert prediction probabilities -> prediction labels
|
74 |
+
target_image_pred_label = torch.argmax(target_image_pred_probs, dim=1)
|
75 |
+
|
76 |
+
class_name = class_names[target_image_pred_label]
|
77 |
+
|
78 |
+
# Split class_name to mushroom type and toxicity profile
|
79 |
+
class_type, toxicity = class_name.rsplit("_", 1)
|
80 |
+
|
81 |
+
# 4 decimal points precision
|
82 |
+
prob = round(target_image_pred_probs.max().item(), 4)
|
83 |
+
|
84 |
+
return class_type, toxicity, prob
|
85 |
+
|
86 |
+
except Exception as e:
|
87 |
+
logger.error("Classification process error: {e}")
|
88 |
+
raise HTTPException(
|
89 |
+
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
90 |
+
detail="Classification process failed due to an internal error. Contact support if this persists.",
|
91 |
+
)
|
src/config.py
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
MODEL_ID = "blasisd/musheff"
|
src/dependencies.py
ADDED
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from fastapi import Request, HTTPException
|
2 |
+
from transformers import PreTrainedModel
|
3 |
+
from transformers.image_processing_utils import BaseImageProcessor
|
4 |
+
|
5 |
+
|
6 |
+
def get_model(request: Request) -> PreTrainedModel:
|
7 |
+
if not hasattr(request.app.state, "model"):
|
8 |
+
raise HTTPException(status_code=500, detail="Model not loaded")
|
9 |
+
return request.app.state.model
|
10 |
+
|
11 |
+
|
12 |
+
def get_preprocessor(request: Request) -> BaseImageProcessor:
|
13 |
+
if not hasattr(request.app.state, "preprocessor"):
|
14 |
+
raise HTTPException(status_code=500, detail="Preprocessor not loaded")
|
15 |
+
return request.app.state.preprocessor
|
src/main.py
ADDED
@@ -0,0 +1,47 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from contextlib import asynccontextmanager
|
2 |
+
from fastapi import FastAPI
|
3 |
+
from transformers import (
|
4 |
+
AutoImageProcessor,
|
5 |
+
AutoModel,
|
6 |
+
)
|
7 |
+
|
8 |
+
import src.config as config
|
9 |
+
|
10 |
+
from src.router import router
|
11 |
+
|
12 |
+
|
13 |
+
@asynccontextmanager
|
14 |
+
async def lifespan(app: FastAPI):
|
15 |
+
# Load models during startup
|
16 |
+
|
17 |
+
app.state.model = AutoModel.from_pretrained(
|
18 |
+
config.MODEL_ID,
|
19 |
+
trust_remote_code=True,
|
20 |
+
low_cpu_mem_usage=True, # Activates memory-efficient loading
|
21 |
+
device_map="auto", # Distributes layers across devices
|
22 |
+
)
|
23 |
+
|
24 |
+
app.state.preprocessor = AutoImageProcessor.from_pretrained(
|
25 |
+
config.MODEL_ID,
|
26 |
+
trust_remote_code=True,
|
27 |
+
use_fast=True,
|
28 |
+
)
|
29 |
+
|
30 |
+
yield
|
31 |
+
|
32 |
+
# Cleanup during shutdown (e.g., GPU memory)
|
33 |
+
del app.state.model
|
34 |
+
del app.state.preprocessor
|
35 |
+
|
36 |
+
|
37 |
+
app = FastAPI(
|
38 |
+
description="Mushrooms Classification API", version="0.1.0", lifespan=lifespan
|
39 |
+
)
|
40 |
+
|
41 |
+
|
42 |
+
@app.get("/")
|
43 |
+
async def root():
|
44 |
+
return {"message": "Welcome to Mushrooms Classification API 🍄"}
|
45 |
+
|
46 |
+
|
47 |
+
app.include_router(router)
|
src/router.py
ADDED
@@ -0,0 +1,59 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import io
|
2 |
+
import logging
|
3 |
+
|
4 |
+
from fastapi import APIRouter, HTTPException, status, UploadFile, Depends
|
5 |
+
from PIL import Image
|
6 |
+
|
7 |
+
from src.dependencies import get_model, get_preprocessor
|
8 |
+
from src.schema import MushroomClassification
|
9 |
+
from src.services import classify_mushroom_in_image_svc
|
10 |
+
|
11 |
+
|
12 |
+
logger = logging.getLogger(__name__)
|
13 |
+
|
14 |
+
router = APIRouter()
|
15 |
+
|
16 |
+
|
17 |
+
@router.post(
|
18 |
+
"/classify", response_model=MushroomClassification, status_code=status.HTTP_200_OK
|
19 |
+
)
|
20 |
+
async def classify_mushroom_in_image(
|
21 |
+
image_file: UploadFile,
|
22 |
+
model=Depends(get_model),
|
23 |
+
preprocessor=Depends(get_preprocessor),
|
24 |
+
):
|
25 |
+
"""Open uploaded image file and call mushroom classification
|
26 |
+
service.
|
27 |
+
|
28 |
+
:param image_file: the uploaded image file
|
29 |
+
:type image_file: UploadFile
|
30 |
+
:param model: the pretrained model, defaults to Depends(get_model)
|
31 |
+
:type model: PreTrainedModel, optional
|
32 |
+
:param preprocessor: the preprocessor for image input transforms, defaults to Depends(get_preprocessor)
|
33 |
+
:type preprocessor: BaseImageProcessor, optional
|
34 |
+
:raises HTTPException: Internal Server Error in case of model/preprocessor loading failure or some uknown error,
|
35 |
+
or Bad Request Error in case of corrupted or invalid uploaded file
|
36 |
+
:return: mushroom_type, toxicity_profile, classification_confidence
|
37 |
+
:rtype: MushroomClassification
|
38 |
+
"""
|
39 |
+
logger.info(f"Classify image: {image_file.filename}")
|
40 |
+
|
41 |
+
try:
|
42 |
+
request_object_content = await image_file.read()
|
43 |
+
img = Image.open(io.BytesIO(request_object_content))
|
44 |
+
if img.mode != "RGB":
|
45 |
+
img = img.convert("RGB")
|
46 |
+
except Exception as e:
|
47 |
+
logger.error(f"Error reading file: {str(e)}")
|
48 |
+
raise HTTPException(
|
49 |
+
status_code=status.HTTP_400_BAD_REQUEST,
|
50 |
+
detail="Failed to read uploaded file. The file may be corrupted or invalid.",
|
51 |
+
)
|
52 |
+
|
53 |
+
# Get class_name predicted and prediction probability
|
54 |
+
class_name, toxicity, confidence = await classify_mushroom_in_image_svc(
|
55 |
+
img, model, preprocessor
|
56 |
+
)
|
57 |
+
return MushroomClassification(
|
58 |
+
mushroom_type=class_name, toxicity_profile=toxicity, confidence=confidence
|
59 |
+
)
|
src/schema.py
ADDED
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from pydantic import BaseModel
|
2 |
+
|
3 |
+
|
4 |
+
class MushroomClassification(BaseModel):
|
5 |
+
mushroom_type: str
|
6 |
+
toxicity_profile: str
|
7 |
+
confidence: float
|
src/services.py
ADDED
@@ -0,0 +1,72 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import logging
|
2 |
+
|
3 |
+
from typing import Tuple
|
4 |
+
|
5 |
+
import torch
|
6 |
+
|
7 |
+
from fastapi import HTTPException, status
|
8 |
+
from PIL import Image
|
9 |
+
from transformers import PreTrainedModel
|
10 |
+
from transformers.image_processing_utils import BaseImageProcessor
|
11 |
+
|
12 |
+
|
13 |
+
logger = logging.getLogger(__name__)
|
14 |
+
|
15 |
+
|
16 |
+
async def classify_mushroom_in_image_svc(
|
17 |
+
img: Image.Image, model: PreTrainedModel, preprocessor: BaseImageProcessor
|
18 |
+
) -> Tuple[str, str, str]:
|
19 |
+
"""Service used to classify a mushroom shown in an image.
|
20 |
+
The mushroom is classified to one of many well known mushroom classes/types,
|
21 |
+
as well as according to its toxicity profile (i.e. edible or poisonous).
|
22 |
+
Additionally, a probability is returned showing confidence of classification.
|
23 |
+
|
24 |
+
:param img: the input image of the mushroom to be classified
|
25 |
+
:type img: Image.Image
|
26 |
+
:param model: the pretrained model
|
27 |
+
:type model: PretrainedModel
|
28 |
+
:param preprocessor: the auto preprocessor for image transforms (rescales, crops, normalizations etc.)
|
29 |
+
:type preprocessor: BaseImageProcessor
|
30 |
+
:raises HTTPException: Internal Server Error
|
31 |
+
:return: mushroom_type, toxicity_profile, classification_confidence
|
32 |
+
:rtype: Tuple[str, str, float]
|
33 |
+
"""
|
34 |
+
|
35 |
+
try:
|
36 |
+
|
37 |
+
logger.debug("Loading classification model.")
|
38 |
+
|
39 |
+
inputs = preprocessor(img, return_tensors="pt").to(model.device)
|
40 |
+
|
41 |
+
# Turn on model evaluation mode and inference mode
|
42 |
+
model.eval()
|
43 |
+
with torch.inference_mode():
|
44 |
+
logger.debug("Starting classification process...")
|
45 |
+
|
46 |
+
# Make a prediction on image with an extra dimension and send it to the target device
|
47 |
+
target_image_pred = model(inputs["pixel_values"])
|
48 |
+
|
49 |
+
# Convert logits -> prediction probabilities (using torch.softmax() for multi-class classification)
|
50 |
+
target_image_pred_probs = torch.softmax(target_image_pred, dim=1)
|
51 |
+
|
52 |
+
# model predicts one of the 12 potential mushroom classes
|
53 |
+
predicted_label = target_image_pred.argmax(dim=1).item()
|
54 |
+
|
55 |
+
# Get the label/class name of the prediction made using id2label
|
56 |
+
class_name = model.config.id2label[predicted_label]
|
57 |
+
|
58 |
+
# Split class_name to mushroom type and toxicity profile
|
59 |
+
class_type, toxicity = class_name.rsplit("_", 1)
|
60 |
+
|
61 |
+
# 4 decimal points precision
|
62 |
+
prob = round(target_image_pred_probs.max().item(), 4)
|
63 |
+
|
64 |
+
logger.debug("Finished classification process...")
|
65 |
+
return class_type, toxicity, prob
|
66 |
+
|
67 |
+
except Exception as e:
|
68 |
+
logger.error(f"Classification process error: {e}")
|
69 |
+
raise HTTPException(
|
70 |
+
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
71 |
+
detail="Classification process failed due to an internal error. Contact support if this persists.",
|
72 |
+
)
|