|
import gradio as gr |
|
import torch |
|
from transformers import ViTForImageClassification, ViTFeatureExtractor |
|
from PIL import Image |
|
|
|
|
|
# Hugging Face Hub checkpoint supplying both the fine-tuned ViT weights and
# the matching image-preprocessing configuration.
_CHECKPOINT = 'shahmi0519/banana_artificial_v2'

# num_labels=2 requests a two-way classification head. ignore_mismatched_sizes
# lets loading proceed even if the stored head has a different shape; in that
# case the mismatched head weights are freshly initialized.
# NOTE(review): confirm the checkpoint really has 2 labels, otherwise
# predictions would come from an untrained head.
model = ViTForImageClassification.from_pretrained(
    _CHECKPOINT,
    num_labels=2,
    ignore_mismatched_sizes=True,
)
feature_extractor = ViTFeatureExtractor.from_pretrained(_CHECKPOINT)
|
|
|
|
|
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Move the weights onto that device and switch to evaluation mode once, up
# front. Module.to() and Module.eval() both return the module itself.
model = model.to(device).eval()
|
|
|
|
|
# Display names indexed by the model's predicted class index:
# 0 -> "Artificial", 1 -> "Natural".
# NOTE(review): this order presumably mirrors the checkpoint's
# config.id2label mapping; confirm against model.config.id2label.
class_labels = ["Artificial", "Natural"]
|
|
|
def predict_freshness(image):
    """Classify one uploaded image and return its human-readable label.

    Args:
        image: PIL image delivered by the ``gr.Image(type="pil")`` input, or
            ``None`` when the user submits without uploading anything.

    Returns:
        The ``class_labels`` entry for the arg-max logit, or
        ``"Class <idx>"`` if the model predicts an index that has no name.

    Raises:
        gr.Error: when no image was provided, so the UI shows a readable
            message instead of a processor stack trace.
    """
    if image is None:
        raise gr.Error("Please upload an image first.")

    # ViT image processors expect 3-channel input; RGBA (PNG with alpha),
    # palette, or grayscale uploads would otherwise fail during preprocessing.
    if image.mode != "RGB":
        image = image.convert("RGB")

    inputs = feature_extractor(images=image, return_tensors="pt").to(device)

    # The model is already in eval mode (set once at load time); inference
    # only needs gradient tracking disabled.
    with torch.no_grad():
        logits = model(**inputs).logits

    # A single image is passed, so the batch holds exactly one prediction.
    predicted_class_idx = logits.argmax(-1).item()

    try:
        return class_labels[predicted_class_idx]
    except IndexError:
        # The model head exposes more classes than class_labels names.
        return f"Class {predicted_class_idx}"
|
|
|
|
|
# Static text shown at the top of the Gradio page.
# NOTE(review): the wording talks about "freshness" while the model's labels
# are "Artificial"/"Natural"; confirm this copy is intended.
title = "Freshness Detector"
description = "Upload an image of fruit/vegetable to detect its freshness state"

# Each example row holds one value per input component (here: one image).
# Paths are relative, so these files are presumably expected in the working
# directory next to this script; confirm they ship with the app.
examples = [[filename] for filename in ("apple.jpeg", "banana.jpeg", "tomato.jpeg")]
|
|
|
# Input/output components: a PIL image goes in, a label comes out.
image_input = gr.Image(type="pil", label="Upload Image")
label_output = gr.Label(label="Freshness State")

# Wire predict_freshness into a single-function Gradio interface.
iface = gr.Interface(
    fn=predict_freshness,
    inputs=image_input,
    outputs=label_output,
    title=title,
    description=description,
    examples=examples,
)

# Start the web server; share=True additionally requests a public share link.
iface.launch(share=True)