"""FastAPI service that classifies an uploaded image as a fruit/vegetable.

Uses the Hugging Face model ``ivandrian11/fruit-classifier`` and maps its raw
labels onto ``VALID_CLASSES``; low-confidence or unmappable predictions are
reported as ``"unknown"``.
"""

import os

# Set the Hugging Face cache locations before transformers is imported so the
# library picks them up.
os.environ["HF_HOME"] = "/app/.cache/huggingface"
os.environ["TRANSFORMERS_CACHE"] = "/app/.cache/huggingface/transformers"
os.environ["HF_DATASETS_CACHE"] = "/app/.cache/huggingface/datasets"

import logging
from io import BytesIO

import torch
import torch.nn.functional as F
from fastapi import FastAPI, File, UploadFile
from fastapi.concurrency import run_in_threadpool
from fastapi.responses import JSONResponse
from PIL import Image
from transformers import AutoImageProcessor, AutoModelForImageClassification

logger = logging.getLogger(__name__)

app = FastAPI()


@app.get("/")
async def root():
    """Health check: report that the API process is up."""
    return {"message": "API is running"}


# Load model and processor once at import time; every request reuses them.
model_name = "ivandrian11/fruit-classifier"
processor = AutoImageProcessor.from_pretrained(model_name)
model = AutoModelForImageClassification.from_pretrained(model_name)
model.eval()  # set to evaluation mode

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(DEVICE)

# Top-class softmax probabilities below this value are reported as "unknown".
CONFIDENCE_THRESHOLD = 0.7

VALID_CLASSES = ['apple', 'banana', 'orange', 'tomato', 'bitter gourd', 'capsicum']

# Model label (lower-cased) -> one of VALID_CLASSES.
CLASS_MAPPING = {
    'apple': 'apple',
    'banana': 'banana',
    'orange': 'orange',
    'tomato': 'tomato',
    'bitter gourd': 'bitter gourd',
    'bitter melon': 'bitter gourd',
    'bell pepper': 'capsicum',
    'pepper': 'capsicum',
    'capsicum': 'capsicum',
    'green pepper': 'capsicum',
    'red pepper': 'capsicum',
    'yellow pepper': 'capsicum',
    'granny smith': 'apple',
    'fuji apple': 'apple',
    'gala apple': 'apple',
    'navel orange': 'orange',
    'valencia orange': 'orange',
}

# Every alias (canonical class names included) -> canonical class.
_ALIAS_TO_CLASS = {**{name: name for name in VALID_CLASSES}, **CLASS_MAPPING}
# Longest first, so a multi-word alias such as "bell pepper" is preferred over
# a shorter alias it contains ("pepper").
_ALIASES_LONGEST_FIRST = sorted(_ALIAS_TO_CLASS, key=len, reverse=True)


def _normalize_label(label: str) -> str:
    """Lower-case *label* and turn underscores, hyphens and whitespace runs into single spaces."""
    return " ".join(label.lower().replace("_", " ").replace("-", " ").split())


def _map_label(label: str) -> str:
    """Map a raw model label onto one of VALID_CLASSES, or return "unknown".

    An exact alias match wins. Otherwise the longest alias that occurs in the
    label as a whole word/phrase is used; whole-word matching keeps labels
    such as "pineapple" from being mistaken for "apple".
    """
    normalized = _normalize_label(label)
    exact = _ALIAS_TO_CLASS.get(normalized)
    if exact is not None:
        return exact
    padded = f" {normalized} "
    for alias in _ALIASES_LONGEST_FIRST:
        if f" {alias} " in padded:
            return _ALIAS_TO_CLASS[alias]
    return "unknown"


def classify_fruit(image: Image.Image, threshold: float = CONFIDENCE_THRESHOLD) -> str:
    """Classify *image* and return a class from VALID_CLASSES or "unknown".

    Args:
        image: Image to classify (the /classify endpoint passes an RGB image).
        threshold: Minimum softmax probability of the top prediction; below
            it the result is "unknown". Defaults to CONFIDENCE_THRESHOLD.

    Returns:
        The mapped class name, or "unknown" when the model is not confident
        enough or its label does not correspond to a supported class.
    """
    inputs = processor(images=image, return_tensors="pt").to(DEVICE)
    with torch.no_grad():
        outputs = model(**inputs)
        probabilities = F.softmax(outputs.logits, dim=-1)
        confidence, predicted_idx = torch.max(probabilities, dim=-1)

    if confidence.item() < threshold:
        return "unknown"

    predicted_label = model.config.id2label[predicted_idx.item()]
    return _map_label(predicted_label)


def _load_image(data: bytes) -> Image.Image:
    """Decode raw upload bytes into an RGB PIL image; raises if undecodable."""
    with Image.open(BytesIO(data)) as img:
        # convert() returns a new, fully loaded image, so the source image can
        # be closed when the with-block exits.
        return img.convert("RGB")


@app.post("/classify")
async def classify_image(file: UploadFile = File(...)):
    """Classify an uploaded image file.

    Responds with ``{"prediction": <class>}`` on success. An upload that
    cannot be decoded as an image yields HTTP 400; any other failure yields
    HTTP 500 (details are logged, not returned). Both error responses carry
    ``"prediction": "unknown"`` and an ``"error"`` message.
    """
    try:
        image_bytes = await file.read()
        try:
            # Decoding and inference are synchronous and compute-heavy; run
            # them in the threadpool so they don't block the event loop.
            image = await run_in_threadpool(_load_image, image_bytes)
        except (OSError, Image.DecompressionBombError):
            # PIL raises UnidentifiedImageError (an OSError subclass) for
            # non-images and OSError for truncated data: a client error.
            return JSONResponse(
                content={"prediction": "unknown", "error": "Uploaded file is not a valid image"},
                status_code=400,
            )
        result = await run_in_threadpool(classify_fruit, image)
    except Exception:
        logger.exception("Classification of upload %r failed", file.filename)
        return JSONResponse(
            content={"prediction": "unknown", "error": "Internal server error"},
            status_code=500,
        )
    return JSONResponse(content={"prediction": result})