# Hugging Face file-view header (captured together with the source):
# author: Anjali04-15 | commit 7f8df3c (verified) "Update app.py" | size: 2.4 kB
import logging
import re
from io import BytesIO

import torch
import torch.nn.functional as F
from fastapi import FastAPI, File, UploadFile
from fastapi.concurrency import run_in_threadpool
from fastapi.responses import JSONResponse
from PIL import Image
from transformers import AutoImageProcessor, AutoModelForImageClassification
app = FastAPI()


@app.get("/")
async def root():
    """Liveness check: confirm the service is accepting requests."""
    status_payload = {"message": "API is running"}
    return status_payload
# Load the image processor and classifier once, at import time, so every
# request reuses the same objects; prefer the GPU when one is available.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model_name = "ivandrian11/fruit-classifier"
processor = AutoImageProcessor.from_pretrained(model_name)
model = AutoModelForImageClassification.from_pretrained(model_name)
model.to(DEVICE)
model.eval()  # inference mode (nn.Module.eval: training-only layers like dropout are disabled)
# Canonical class names the API may return (in addition to "unknown").
VALID_CLASSES = ['apple', 'banana', 'orange', 'tomato', 'bitter gourd', 'capsicum']

# Lower-cased model label -> canonical class name.
# Every canonical class maps to itself; the extra entries are synonyms and
# cultivar names that should collapse onto one canonical class.
CLASS_MAPPING = {name: name for name in VALID_CLASSES}
CLASS_MAPPING.update({
    'bitter melon': 'bitter gourd',
    'bell pepper': 'capsicum',
    'pepper': 'capsicum',
    'green pepper': 'capsicum',
    'red pepper': 'capsicum',
    'yellow pepper': 'capsicum',
    'granny smith': 'apple',
    'fuji apple': 'apple',
    'gala apple': 'apple',
    'navel orange': 'orange',
    'valencia orange': 'orange',
})
def _map_label(raw_label: str) -> str:
    """Map a raw model label to a canonical class from VALID_CLASSES, or "unknown"."""
    # Normalise case, "_"/"-" separators and runs of whitespace so labels such
    # as "Bell_Pepper" or "bitter-gourd" hit the CLASS_MAPPING entries.
    label = " ".join(raw_label.lower().replace("_", " ").replace("-", " ").split())

    mapped_class = CLASS_MAPPING.get(label)
    if mapped_class:
        return mapped_class

    # Fallback: whole-word match (optionally plural, e.g. "tomatoes") so that
    # "apple braeburn" maps to "apple" but "pineapple" does not -- a plain
    # substring test would wrongly accept it.
    for valid_class in VALID_CLASSES:
        if re.search(rf"\b{re.escape(valid_class)}(?:e?s)?\b", label):
            return valid_class
    return "unknown"


def classify_fruit(image: Image.Image, *, threshold: float = 0.7) -> str:
    """Classify *image* into one of VALID_CLASSES, or return "unknown".

    Args:
        image: PIL image to classify (the /classify endpoint passes an RGB image).
        threshold: Minimum softmax probability required for the top prediction;
            below it the result is "unknown". Defaults to 0.7.

    Returns:
        A canonical class name from VALID_CLASSES, or "unknown" when the model
        is not confident enough or its label cannot be mapped.
    """
    inputs = processor(images=image, return_tensors="pt").to(DEVICE)
    with torch.no_grad():
        outputs = model(**inputs)
    probabilities = F.softmax(outputs.logits, dim=-1)
    # A single image is passed, so both tensors hold exactly one element.
    confidence, predicted_idx = torch.max(probabilities, dim=-1)
    if confidence.item() < threshold:
        return "unknown"
    return _map_label(model.config.id2label[predicted_idx.item()])
@app.post("/classify")
async def classify_image(file: UploadFile = File(...)):
    """Classify an uploaded image; respond with ``{"prediction": <class>}``.

    Error responses keep the ``{"prediction": "unknown", "error": ...}`` shape:
    HTTP 400 when the upload cannot be decoded as an image (unreadable,
    truncated, or exceeding Pillow's decompression-bomb limit), HTTP 500 for
    any other failure.
    """
    try:
        image_bytes = await file.read()
        try:
            image = Image.open(BytesIO(image_bytes)).convert("RGB")
        except (OSError, Image.DecompressionBombError) as e:
            # Bad input from the client, not a server fault.
            return JSONResponse(
                content={"prediction": "unknown", "error": str(e)},
                status_code=400,
            )
        # Inference is synchronous and compute-bound; run it in a worker thread
        # so it does not block the event loop for concurrent requests.
        result = await run_in_threadpool(classify_fruit, image)
    except Exception as e:
        logging.getLogger(__name__).exception("Image classification failed")
        return JSONResponse(
            content={"prediction": "unknown", "error": str(e)},
            status_code=500,
        )
    return JSONResponse(content={"prediction": result})