Dhahlan2000 commited on
Commit
9837266
·
verified ·
1 Parent(s): deaf075

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +38 -0
app.py ADDED
@@ -0,0 +1,38 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from transformers import ViTForImageClassification, ViTFeatureExtractor
3
+ import gradio as gr
4
+ from PIL import Image
5
+
6
+ # Check if GPU is available
7
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
8
+
9
+ # Load pre-trained ViT model from Hugging Face
10
+ model = ViTForImageClassification.from_pretrained('Dhahlan2000/banana_ripeness_level_detection', num_labels=20)
11
+ model.to(device)
12
+ model.eval()
13
+
14
+ # Load ViT feature extractor
15
+ feature_extractor = ViTFeatureExtractor.from_pretrained('Dhahlan2000/banana_ripeness_level_detection')
16
+
17
+ # Class labels
18
+ predicted_classes = ['Overripe', 'ripe', 'rotten', 'unripe']
19
+
20
+ # Function for inference
21
+ def classify_fruit(image):
22
+ inputs = feature_extractor(images=image, return_tensors="pt").to(device)
23
+ with torch.no_grad():
24
+ outputs = model(**inputs)
25
+ logits = outputs.logits
26
+ predicted_class_idx = logits.argmax(-1).item()
27
+ return predicted_classes[predicted_class_idx]
28
+
29
+ # Gradio UI
30
+ demo = gr.Interface(
31
+ fn=classify_fruit,
32
+ inputs=gr.Image(type="pil"),
33
+ outputs=gr.Label(),
34
+ title="Fruit Ripeness Detection",
35
+ description="Upload an image of a fruit to determine whether it's fresh or rotten."
36
+ )
37
+
38
+ demo.launch()