Dhahlan2000 commited on
Commit
f3fe718
·
verified ·
1 Parent(s): 86853dc

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +57 -0
app.py ADDED
@@ -0,0 +1,57 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ from transformers import ViTForImageClassification, ViTFeatureExtractor
4
+ from PIL import Image
5
+
6
+ # Load model and feature extractor
7
+ model = ViTForImageClassification.from_pretrained('Dhahlan2000/freshness_detector_updated', num_labels=30, ignore_mismatched_sizes=True)
8
+ feature_extractor = ViTFeatureExtractor.from_pretrained('Dhahlan2000/freshness_detector_updated')
9
+
10
+ # Move to GPU if available
11
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
12
+ model = model.to(device)
13
+
14
+ # Class labels (modify according to your model)
15
+ class_labels = [
16
+ "Overripe", "Ripe", "Rotten", "Unripe",
17
+ # Add all 30 class labels here
18
+ ]
19
+
20
+ def predict_freshness(image):
21
+ # Preprocess image
22
+ inputs = feature_extractor(images=image, return_tensors="pt").to(device)
23
+
24
+ # Predict
25
+ model.eval()
26
+ with torch.no_grad():
27
+ outputs = model(**inputs)
28
+ logits = outputs.logits
29
+ predicted_class_idx = logits.argmax(-1).item()
30
+
31
+ # Get label
32
+ try:
33
+ label = class_labels[predicted_class_idx]
34
+ except IndexError:
35
+ label = f"Class {predicted_class_idx}"
36
+
37
+ return label
38
+
39
+ # Create Gradio interface
40
+ title = "Freshness Detector"
41
+ description = "Upload an image of fruit/vegetable to detect its freshness state"
42
+ examples = [
43
+ ["apple.jpg"],
44
+ ["banana.jpg"],
45
+ ["tomato.jpg"]
46
+ ]
47
+
48
+ iface = gr.Interface(
49
+ fn=predict_freshness,
50
+ inputs=gr.Image(type="pil", label="Upload Image"),
51
+ outputs=gr.Label(label="Freshness State"),
52
+ title=title,
53
+ description=description,
54
+ examples=examples
55
+ )
56
+
57
+ iface.launch()