Anjali04-15 commited on
Commit
b827d72
·
verified ·
1 Parent(s): ffabdb7

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +76 -74
app.py CHANGED
@@ -1,75 +1,77 @@
1
- from fastapi import FastAPI, File, UploadFile
2
- from fastapi.responses import JSONResponse
3
- from PIL import Image
4
- from io import BytesIO
5
- import torch
6
- import torch.nn.functional as F
7
- from transformers import AutoImageProcessor, AutoModelForImageClassification
8
-
9
- app = FastAPI()
10
-
11
- @app.get("/")
12
- async def root():
13
- return {"message": "API is running"}
14
-
15
- # Load model and processor
16
- MODEL_NAME = "ivandrian11/fruit-classifier"
17
- processor = AutoImageProcessor.from_pretrained(MODEL_NAME)
18
- model = AutoModelForImageClassification.from_pretrained(MODEL_NAME)
19
- model.eval() # set to evaluation mode
20
- DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
21
- model.to(DEVICE)
22
-
23
- VALID_CLASSES = ['apple', 'banana', 'orange', 'tomato', 'bitter gourd', 'capsicum']
24
-
25
- CLASS_MAPPING = {
26
- 'apple': 'apple',
27
- 'banana': 'banana',
28
- 'orange': 'orange',
29
- 'tomato': 'tomato',
30
- 'bitter gourd': 'bitter gourd',
31
- 'bitter melon': 'bitter gourd',
32
- 'bell pepper': 'capsicum',
33
- 'pepper': 'capsicum',
34
- 'capsicum': 'capsicum',
35
- 'green pepper': 'capsicum',
36
- 'red pepper': 'capsicum',
37
- 'yellow pepper': 'capsicum',
38
- 'granny smith': 'apple',
39
- 'fuji apple': 'apple',
40
- 'gala apple': 'apple',
41
- 'navel orange': 'orange',
42
- 'valencia orange': 'orange'
43
- }
44
-
45
- def classify_fruit(image: Image.Image) -> str:
46
- inputs = processor(images=image, return_tensors="pt").to(DEVICE)
47
- with torch.no_grad():
48
- outputs = model(**inputs)
49
- probabilities = F.softmax(outputs.logits, dim=-1)
50
- confidence, predicted_idx = torch.max(probabilities, dim=-1)
51
- confidence = confidence.item()
52
- predicted_label = model.config.id2label[predicted_idx.item()].lower()
53
-
54
- if confidence < 0.7:
55
- return "unknown"
56
-
57
- mapped_class = CLASS_MAPPING.get(predicted_label, None)
58
- if mapped_class:
59
- return mapped_class
60
-
61
- for valid_class in VALID_CLASSES:
62
- if valid_class in predicted_label:
63
- return valid_class
64
-
65
- return "unknown"
66
-
67
- @app.post("/classify")
68
- async def classify_image(file: UploadFile = File(...)):
69
- try:
70
- image_bytes = await file.read()
71
- image = Image.open(BytesIO(image_bytes)).convert("RGB")
72
- result = classify_fruit(image)
73
- return JSONResponse(content={"prediction": result})
74
- except Exception as e:
 
 
75
  return JSONResponse(content={"prediction": "unknown", "error": str(e)}, status_code=500)
 
1
+ import os
2
+ from fastapi import FastAPI, File, UploadFile
3
+ from fastapi.responses import JSONResponse
4
+ from PIL import Image
5
+ from io import BytesIO
6
+ import torch
7
+ import torch.nn.functional as F
8
+ from transformers import AutoImageProcessor, AutoModelForImageClassification
9
+
10
+ app = FastAPI()
11
+ os.environ["TRANSFORMERS_CACHE"] = "/tmp/hf_cache"
12
+
13
+ @app.get("/")
14
+ async def root():
15
+ return {"message": "API is running"}
16
+
17
+ # Load model and processor
18
+ MODEL_NAME = "ivandrian11/fruit-classifier"
19
+ processor = AutoImageProcessor.from_pretrained(MODEL_NAME)
20
+ model = AutoModelForImageClassification.from_pretrained(MODEL_NAME)
21
+ model.eval() # set to evaluation mode
22
+ DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
23
+ model.to(DEVICE)
24
+
25
+ VALID_CLASSES = ['apple', 'banana', 'orange', 'tomato', 'bitter gourd', 'capsicum']
26
+
27
+ CLASS_MAPPING = {
28
+ 'apple': 'apple',
29
+ 'banana': 'banana',
30
+ 'orange': 'orange',
31
+ 'tomato': 'tomato',
32
+ 'bitter gourd': 'bitter gourd',
33
+ 'bitter melon': 'bitter gourd',
34
+ 'bell pepper': 'capsicum',
35
+ 'pepper': 'capsicum',
36
+ 'capsicum': 'capsicum',
37
+ 'green pepper': 'capsicum',
38
+ 'red pepper': 'capsicum',
39
+ 'yellow pepper': 'capsicum',
40
+ 'granny smith': 'apple',
41
+ 'fuji apple': 'apple',
42
+ 'gala apple': 'apple',
43
+ 'navel orange': 'orange',
44
+ 'valencia orange': 'orange'
45
+ }
46
+
47
+ def classify_fruit(image: Image.Image) -> str:
48
+ inputs = processor(images=image, return_tensors="pt").to(DEVICE)
49
+ with torch.no_grad():
50
+ outputs = model(**inputs)
51
+ probabilities = F.softmax(outputs.logits, dim=-1)
52
+ confidence, predicted_idx = torch.max(probabilities, dim=-1)
53
+ confidence = confidence.item()
54
+ predicted_label = model.config.id2label[predicted_idx.item()].lower()
55
+
56
+ if confidence < 0.7:
57
+ return "unknown"
58
+
59
+ mapped_class = CLASS_MAPPING.get(predicted_label, None)
60
+ if mapped_class:
61
+ return mapped_class
62
+
63
+ for valid_class in VALID_CLASSES:
64
+ if valid_class in predicted_label:
65
+ return valid_class
66
+
67
+ return "unknown"
68
+
69
+ @app.post("/classify")
70
+ async def classify_image(file: UploadFile = File(...)):
71
+ try:
72
+ image_bytes = await file.read()
73
+ image = Image.open(BytesIO(image_bytes)).convert("RGB")
74
+ result = classify_fruit(image)
75
+ return JSONResponse(content={"prediction": result})
76
+ except Exception as e:
77
  return JSONResponse(content={"prediction": "unknown", "error": str(e)}, status_code=500)