import os

os.environ["HF_HOME"] = "/app/.cache/huggingface"
os.environ["TRANSFORMERS_CACHE"] = "/app/.cache/huggingface/transformers"
os.environ["HF_DATASETS_CACHE"] = "/app/.cache/huggingface/datasets"

from fastapi import FastAPI, File, UploadFile
from fastapi.responses import JSONResponse
from PIL import Image
from io import BytesIO
import torch
import torch.nn.functional as F
from transformers import AutoImageProcessor, AutoModelForImageClassification

app = FastAPI()

@app.get("/")
async def root():
    return {"message": "API is running"}

# Load model and processor
model_name = "ivandrian11/fruit-classifier"
processor = AutoImageProcessor.from_pretrained(model_name)
model = AutoModelForImageClassification.from_pretrained(model_name)
model.eval()  # set to evaluation mode
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(DEVICE)

# Canonical labels the API is allowed to return; anything else becomes "unknown".
VALID_CLASSES = ['apple', 'banana', 'orange', 'tomato', 'bitter gourd', 'capsicum']

# Map raw model labels (including synonyms and variety names) to the canonical classes.
CLASS_MAPPING = {
    'apple': 'apple',
    'banana': 'banana',
    'orange': 'orange',
    'tomato': 'tomato',
    'bitter gourd': 'bitter gourd',
    'bitter melon': 'bitter gourd',
    'bell pepper': 'capsicum',
    'pepper': 'capsicum',
    'capsicum': 'capsicum',
    'green pepper': 'capsicum',
    'red pepper': 'capsicum',
    'yellow pepper': 'capsicum',
    'granny smith': 'apple',
    'fuji apple': 'apple',
    'gala apple': 'apple',
    'navel orange': 'orange',
    'valencia orange': 'orange'
}

def classify_fruit(image: Image.Image) -> str:
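    """Classify a PIL image and return one of VALID_CLASSES or "unknown"."""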
    inputs = processor(images=image, return_tensors="pt").to(DEVICE)
    with torch.no_grad():
        outputs = model(**inputs)
        probabilities = F.softmax(outputs.logits, dim=-1)
        confidence, predicted_idx = torch.max(probabilities, dim=-1)
        confidence = confidence.item()
        predicted_label = model.config.id2label[predicted_idx.item()].lower()

    # Reject low-confidence predictions instead of guessing.
    if confidence < 0.7:
        return "unknown"

    # Try an exact lookup in the synonym map first.
    mapped_class = CLASS_MAPPING.get(predicted_label)
    if mapped_class:
        return mapped_class

    # Fall back to a substring match, e.g. "red apple" -> "apple".
    for valid_class in VALID_CLASSES:
        if valid_class in predicted_label:
            return valid_class

    return "unknown"

@app.post("/classify")
async def classify_image(file: UploadFile = File(...)):
    try:
        image_bytes = await file.read()
        image = Image.open(BytesIO(image_bytes)).convert("RGB")
        result = classify_fruit(image)
        return JSONResponse(content={"prediction": result})
    except Exception as e:
        return JSONResponse(content={"prediction": "unknown", "error": str(e)}, status_code=500)