Anjali04-15's picture
Update app.py
3628c89 verified
import os

# Redirect every Hugging Face cache to a writable directory under /app.
# This runs before `transformers` is imported further down, so the library
# sees these locations when it reads its cache settings from the environment.
# NOTE(review): newer transformers releases treat TRANSFORMERS_CACHE as
# deprecated in favour of HF_HOME — confirm against the pinned version.
os.environ.update(
    {
        "HF_HOME": "/app/.cache/huggingface",
        "TRANSFORMERS_CACHE": "/app/.cache/huggingface/transformers",
        "HF_DATASETS_CACHE": "/app/.cache/huggingface/datasets",
    }
)
from io import BytesIO
import logging
import re

from fastapi import FastAPI, File, UploadFile
from fastapi.concurrency import run_in_threadpool
from fastapi.responses import JSONResponse
from PIL import Image
from PIL import UnidentifiedImageError
import torch
import torch.nn.functional as F
from transformers import AutoImageProcessor, AutoModelForImageClassification
# ASGI application; the routes below are registered on it via decorators.
app = FastAPI()


@app.get("/")
async def root():
    """Simple health-check endpoint confirming the service is reachable.

    Returns:
        A dict with a fixed status message, serialised to JSON by FastAPI.
    """
    payload = {"message": "API is running"}
    return payload
# Run on the GPU when torch can see one, otherwise on the CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load the image processor and classification model once, at import time,
# so every request reuses the same weights.
model_name = "ivandrian11/fruit-classifier"
processor = AutoImageProcessor.from_pretrained(model_name)
model = AutoModelForImageClassification.from_pretrained(model_name).to(DEVICE)
# Inference mode: layers such as dropout / batch-norm switch to eval behaviour.
model.eval()
# Canonical labels the API may return (in addition to "unknown").
VALID_CLASSES = ["apple", "banana", "orange", "tomato", "bitter gourd", "capsicum"]

# Lower-cased model label -> canonical class from VALID_CLASSES.
# Synonyms and named varieties collapse onto the same canonical class.
CLASS_MAPPING = {
    # Canonical names map to themselves.
    "apple": "apple",
    "banana": "banana",
    "orange": "orange",
    "tomato": "tomato",
    "bitter gourd": "bitter gourd",
    # Synonym for bitter gourd.
    "bitter melon": "bitter gourd",
    # Pepper variants are all reported as capsicum.
    "bell pepper": "capsicum",
    "pepper": "capsicum",
    "capsicum": "capsicum",
    "green pepper": "capsicum",
    "red pepper": "capsicum",
    "yellow pepper": "capsicum",
    # Apple cultivars.
    "granny smith": "apple",
    "fuji apple": "apple",
    "gala apple": "apple",
    # Orange cultivars.
    "navel orange": "orange",
    "valencia orange": "orange",
}
def _normalize_label(label: str) -> str:
    """Lower-case *label* and collapse ``_`` / ``-`` / whitespace runs to single spaces."""
    return " ".join(label.lower().replace("_", " ").replace("-", " ").split())


def _resolve_label(label: str) -> str:
    """Map a normalized model label onto one of VALID_CLASSES, or ``"unknown"``."""
    mapped_class = CLASS_MAPPING.get(label)
    if mapped_class:
        return mapped_class
    # Fall back to whole-word matching (optionally pluralised with "s"/"es") so
    # "green apple" or "tomatoes" still resolve, while "pineapple" no longer
    # collapses onto "apple" as it did with a plain substring test.
    for valid_class in VALID_CLASSES:
        if re.search(rf"\b{re.escape(valid_class)}(?:e?s)?\b", label):
            return valid_class
    return "unknown"


def classify_fruit(image: Image.Image, *, min_confidence: float = 0.7) -> str:
    """Classify a single PIL image into one of VALID_CLASSES.

    Args:
        image: The image to classify (the caller converts it to RGB).
        min_confidence: Minimum softmax probability of the top class required
            to accept the prediction; anything below yields ``"unknown"``.

    Returns:
        A canonical class name from VALID_CLASSES, or ``"unknown"`` when the
        model is not confident enough or its label maps to no supported class.
    """
    inputs = processor(images=image, return_tensors="pt").to(DEVICE)
    with torch.no_grad():
        outputs = model(**inputs)
    probabilities = F.softmax(outputs.logits, dim=-1)
    # A single image is processed, so .item() reads the one top-1 result.
    confidence, predicted_idx = torch.max(probabilities, dim=-1)
    if confidence.item() < min_confidence:
        return "unknown"
    predicted_label = _normalize_label(model.config.id2label[predicted_idx.item()])
    return _resolve_label(predicted_label)
def _predict_from_bytes(image_bytes: bytes) -> str:
    """Decode raw upload bytes into an RGB image and classify it (blocking)."""
    with Image.open(BytesIO(image_bytes)) as raw_image:
        image = raw_image.convert("RGB")
    return classify_fruit(image)


@app.post("/classify")
async def classify_image(file: UploadFile = File(...)):
    """Classify an uploaded image file.

    Returns:
        JSON ``{"prediction": <class>}`` on success. On failure it returns
        ``{"prediction": "unknown", "error": <message>}`` with status 400 if
        the upload is not a decodable image, or 500 for any other error.
    """
    try:
        image_bytes = await file.read()
        # Image decoding and model inference are blocking, compute-heavy calls.
        # Running them in FastAPI's thread pool keeps the event loop free to
        # serve other requests (including "/") while inference is in flight.
        result = await run_in_threadpool(_predict_from_bytes, image_bytes)
    except UnidentifiedImageError as exc:
        # Pillow could not recognise the upload as an image: a client error.
        return JSONResponse(
            content={"prediction": "unknown", "error": str(exc)}, status_code=400
        )
    except Exception as exc:  # request boundary: log it, report it, keep serving
        logging.getLogger(__name__).exception(
            "Classification failed for upload %r", file.filename
        )
        return JSONResponse(
            content={"prediction": "unknown", "error": str(exc)}, status_code=500
        )
    return JSONResponse(content={"prediction": result})