Dhahlan2000 commited on
Commit
10bb1e5
·
verified ·
1 Parent(s): f4d858f

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +86 -0
app.py ADDED
@@ -0,0 +1,86 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ from transformers import ViTForImageClassification, ViTFeatureExtractor
4
+ from PIL import Image
5
+
6
+ # Load model and feature extractor
7
+ model = ViTForImageClassification.from_pretrained('shahmi0519/fypvit', num_labels=30, ignore_mismatched_sizes=True)
8
+ feature_extractor = ViTFeatureExtractor.from_pretrained('shahmi0519/fypvit')
9
+
10
+ # Move to GPU if available
11
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
12
+ model = model.to(device)
13
+ model.eval()
14
+
15
+ # Class labels (modify according to your model)
16
+ class_labels = [
17
+ "Bellpepper_fresh",
18
+ "Bellpepper_intermediate_fresh",
19
+ "Bellpepper_rotten",
20
+ "Carrot_fresh",
21
+ "Carrot_intermediate_fresh",
22
+ "Carrot_rotten",
23
+ "Cucumber_fresh",
24
+ "Cucumber_intermediate_fresh",
25
+ "Cucumber_rotten",
26
+ "Potato_fresh",
27
+ "Potato_intermediate_fresh",
28
+ "Potato_rotten",
29
+ "Tomato_fresh",
30
+ "Tomato_intermediate_fresh",
31
+ "Tomato_rotten",
32
+ "ripe_apple",
33
+ "ripe_banana",
34
+ "ripe_mango",
35
+ "ripe_oranges",
36
+ "ripe_strawberry",
37
+ "rotten_apple",
38
+ "rotten_banana",
39
+ "rotten_mango",
40
+ "rotten_oranges",
41
+ "rotten_strawberry",
42
+ "unripe_apple",
43
+ "unripe_banana",
44
+ "unripe_mango",
45
+ "unripe_oranges",
46
+ "unripe_strawberry"
47
+ ]
48
+
49
+ def predict_freshness(image):
50
+ # Preprocess image
51
+ inputs = feature_extractor(images=image, return_tensors="pt").to(device)
52
+
53
+ # Predict
54
+ model.eval()
55
+ with torch.no_grad():
56
+ outputs = model(**inputs)
57
+ logits = outputs.logits
58
+ predicted_class_idx = logits.argmax(-1).item()
59
+
60
+ # Get label
61
+ try:
62
+ label = class_labels[predicted_class_idx]
63
+ except IndexError:
64
+ label = f"Class {predicted_class_idx}"
65
+
66
+ return label
67
+
68
+ # Create Gradio interface
69
+ title = "Freshness Detector"
70
+ description = "Upload an image of fruit/vegetable to detect its freshness state"
71
+ examples = [
72
+ ["apple.jpeg"],
73
+ ["banana.jpeg"],
74
+ ["tomato.jpeg"]
75
+ ]
76
+
77
+ iface = gr.Interface(
78
+ fn=predict_freshness,
79
+ inputs=gr.Image(type="pil", label="Upload Image"),
80
+ outputs=gr.Label(label="Freshness State"),
81
+ title=title,
82
+ description=description,
83
+ examples=examples
84
+ )
85
+
86
+ iface.launch(share=True)