Anjali04-15 commited on
Commit
8198a3e
·
verified ·
1 Parent(s): b5111f5

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +75 -0
app.py ADDED
@@ -0,0 +1,75 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastapi import FastAPI, File, UploadFile
2
+ from fastapi.responses import JSONResponse
3
+ from PIL import Image
4
+ from io import BytesIO
5
+ import torch
6
+ import torch.nn.functional as F
7
+ from transformers import AutoImageProcessor, AutoModelForImageClassification
8
+
9
+ app = FastAPI()
10
+
11
+ @app.get("/")
12
+ async def root():
13
+ return {"message": "API is running"}
14
+
15
+ # Load model and processor
16
+ MODEL_NAME = "ivandrian11/fruit-classifier"
17
+ processor = AutoImageProcessor.from_pretrained(MODEL_NAME)
18
+ model = AutoModelForImageClassification.from_pretrained(MODEL_NAME)
19
+ model.eval() # set to evaluation mode
20
+ DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
21
+ model.to(DEVICE)
22
+
23
+ VALID_CLASSES = ['apple', 'banana', 'orange', 'tomato', 'bitter gourd', 'capsicum']
24
+
25
+ CLASS_MAPPING = {
26
+ 'apple': 'apple',
27
+ 'banana': 'banana',
28
+ 'orange': 'orange',
29
+ 'tomato': 'tomato',
30
+ 'bitter gourd': 'bitter gourd',
31
+ 'bitter melon': 'bitter gourd',
32
+ 'bell pepper': 'capsicum',
33
+ 'pepper': 'capsicum',
34
+ 'capsicum': 'capsicum',
35
+ 'green pepper': 'capsicum',
36
+ 'red pepper': 'capsicum',
37
+ 'yellow pepper': 'capsicum',
38
+ 'granny smith': 'apple',
39
+ 'fuji apple': 'apple',
40
+ 'gala apple': 'apple',
41
+ 'navel orange': 'orange',
42
+ 'valencia orange': 'orange'
43
+ }
44
+
45
+ def classify_fruit(image: Image.Image) -> str:
46
+ inputs = processor(images=image, return_tensors="pt").to(DEVICE)
47
+ with torch.no_grad():
48
+ outputs = model(**inputs)
49
+ probabilities = F.softmax(outputs.logits, dim=-1)
50
+ confidence, predicted_idx = torch.max(probabilities, dim=-1)
51
+ confidence = confidence.item()
52
+ predicted_label = model.config.id2label[predicted_idx.item()].lower()
53
+
54
+ if confidence < 0.7:
55
+ return "unknown"
56
+
57
+ mapped_class = CLASS_MAPPING.get(predicted_label, None)
58
+ if mapped_class:
59
+ return mapped_class
60
+
61
+ for valid_class in VALID_CLASSES:
62
+ if valid_class in predicted_label:
63
+ return valid_class
64
+
65
+ return "unknown"
66
+
67
+ @app.post("/classify")
68
+ async def classify_image(file: UploadFile = File(...)):
69
+ try:
70
+ image_bytes = await file.read()
71
+ image = Image.open(BytesIO(image_bytes)).convert("RGB")
72
+ result = classify_fruit(image)
73
+ return JSONResponse(content={"prediction": result})
74
+ except Exception as e:
75
+ return JSONResponse(content={"prediction": "unknown", "error": str(e)}, status_code=500)