|
import gradio as gr |
|
import torch |
|
from transformers import ViTForImageClassification, ViTFeatureExtractor |
|
from PIL import Image |
|
|
|
|
|
# Hugging Face Hub repository that provides both the classifier weights and
# the matching image-preprocessing configuration.
MODEL_ID = 'shahmi0519/fypvit'

# Load the ViT classifier with a 30-way head (class_labels has 30 entries).
# NOTE(review): with ignore_mismatched_sizes=True a checkpoint whose head is
# not 30-way loads without error and the mismatched head weights are newly
# initialised -- confirm the published checkpoint really has 30 outputs.
model = ViTForImageClassification.from_pretrained(
    MODEL_ID,
    num_labels=30,
    ignore_mismatched_sizes=True,
)

# Preprocessor loaded from the same repository as the model.
# NOTE(review): ViTFeatureExtractor is deprecated in recent transformers
# releases in favour of ViTImageProcessor -- consider migrating.
feature_extractor = ViTFeatureExtractor.from_pretrained(MODEL_ID)

# Run on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)

# Inference only: switch the model to evaluation mode.
model.eval()
|
|
|
|
|
# Human-readable names for the classifier outputs; position i names output
# index i, so the order of this list must not be changed.
# NOTE(review): the order is presumably the one used at training time --
# verify against the training label mapping before editing.
class_labels = [
    # Vegetables: <Name>_<state>
    "Bellpepper_fresh",
    "Bellpepper_intermediate_fresh",
    "Bellpepper_rotten",
    "Carrot_fresh",
    "Carrot_intermediate_fresh",
    "Carrot_rotten",
    "Cucumber_fresh",
    "Cucumber_intermediate_fresh",
    "Cucumber_rotten",
    "Potato_fresh",
    "Potato_intermediate_fresh",
    "Potato_rotten",
    "Tomato_fresh",
    "Tomato_intermediate_fresh",
    "Tomato_rotten",
    # Fruits: <state>_<name>
    "ripe_apple",
    "ripe_banana",
    "ripe_mango",
    "ripe_oranges",
    "ripe_strawberry",
    "rotten_apple",
    "rotten_banana",
    "rotten_mango",
    "rotten_oranges",
    "rotten_strawberry",
    "unripe_apple",
    "unripe_banana",
    "unripe_mango",
    "unripe_oranges",
    "unripe_strawberry",
]
|
|
|
def predict_freshness(image):
    """Classify a produce image and return the name of its freshness class.

    Args:
        image: The picture to classify. The Gradio interface in this file
            passes a PIL image, or ``None`` when the user submits without
            uploading anything.

    Returns:
        The entry of ``class_labels`` at the predicted index, or the string
        ``"Class <index>"`` when that index has no entry in the list.

    Raises:
        gr.Error: If ``image`` is ``None``; Gradio shows the message to the
            user instead of an opaque preprocessing failure.
    """
    if image is None:
        raise gr.Error("Please upload an image before submitting.")

    # Direct callers may pass RGBA / grayscale / palette images; convert them
    # so the preprocessor always receives three channels.
    # NOTE(review): presumably the preprocessor expects RGB input -- confirm.
    if isinstance(image, Image.Image) and image.mode != "RGB":
        image = image.convert("RGB")

    inputs = feature_extractor(images=image, return_tensors="pt").to(device)

    # Idempotent; keeps inference correct even if the model was switched to
    # training mode elsewhere.
    model.eval()
    with torch.no_grad():
        logits = model(**inputs).logits
    predicted_class_idx = logits.argmax(-1).item()

    # Fall back to a generic name if the model has more outputs than labels.
    if 0 <= predicted_class_idx < len(class_labels):
        return class_labels[predicted_class_idx]
    return f"Class {predicted_class_idx}"
|
|
|
|
|
# Text shown at the top of the web page.
title = "Freshness Detector"
description = "Upload an image of fruit/vegetable to detect its freshness state"

# Sample inputs offered in the UI; one single-element row per example because
# the interface has exactly one input component.
# NOTE(review): relative paths -- presumably these files ship next to this
# script; verify they exist in the deployment.
examples = [
    ["apple.jpeg"],
    ["banana.jpeg"],
    ["tomato.jpeg"],
]

# Wire the classifier to a single image input (delivered as a PIL image) and
# a label output.
iface = gr.Interface(
    fn=predict_freshness,
    inputs=gr.Image(type="pil", label="Upload Image"),
    outputs=gr.Label(label="Freshness State"),
    title=title,
    description=description,
    examples=examples,
)

# Start the web server; share=True additionally requests a public share link.
iface.launch(share=True)