"""Gradio demo that classifies an uploaded image with a fine-tuned ViT model."""

import gradio as gr
import torch
from transformers import ViTForImageClassification, ViTFeatureExtractor
from PIL import Image

# Hugging Face Hub checkpoint used for both the model and its preprocessor.
MODEL_ID = "shahmi0519/banana_artificial_v2"

# Load model and feature extractor.
# NOTE(review): with ignore_mismatched_sizes=True, checkpoint weights whose
# shape does not fit a 2-label head are re-initialised instead of raising an
# error. If the checkpoint was trained with a different label count, the
# classifier head would be untrained; confirm the checkpoint has 2 labels.
model = ViTForImageClassification.from_pretrained(
    MODEL_ID,
    num_labels=2,
    ignore_mismatched_sizes=True,
)
feature_extractor = ViTFeatureExtractor.from_pretrained(MODEL_ID)

# Move to GPU if available; inference-only, so switch to eval mode once here.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)
model.eval()

# Class labels, indexed by the model's output position
# (modify according to your model).
class_labels = [
    "Artificial",
    "Natural",
]


def predict_freshness(image):
    """Classify an uploaded image and return the predicted class label.

    Args:
        image: PIL image supplied by ``gr.Image(type="pil")``, or ``None``
            when the user submits without uploading anything.

    Returns:
        The entry of ``class_labels`` at the arg-max logit, or
        ``"Class <idx>"`` if the predicted index is outside that list.

    Raises:
        gr.Error: if no image was provided. Gradio shows the message to
            the user instead of a preprocessing traceback.
    """
    if image is None:
        raise gr.Error("Please upload an image.")

    # Force 3-channel RGB so RGBA / palette / greyscale uploads reach the
    # preprocessor in a consistent format (presumably the RGB input the
    # ViT checkpoint was trained on; confirm).
    image = image.convert("RGB")

    # Preprocess and move the resulting tensors to the model's device.
    inputs = feature_extractor(images=image, return_tensors="pt").to(device)

    # Predict (no gradients needed; the model is already in eval mode).
    with torch.no_grad():
        logits = model(**inputs).logits
    predicted_class_idx = logits.argmax(-1).item()

    # Map the index to a human-readable label, with a generic fallback.
    try:
        return class_labels[predicted_class_idx]
    except IndexError:
        return f"Class {predicted_class_idx}"


# Create Gradio interface
title = "Freshness Detector"
description = "Upload an image of fruit/vegetable to detect its freshness state"
examples = [
    ["apple.jpeg"],
    ["banana.jpeg"],
    ["tomato.jpeg"],
]

iface = gr.Interface(
    fn=predict_freshness,
    inputs=gr.Image(type="pil", label="Upload Image"),
    outputs=gr.Label(label="Freshness State"),
    title=title,
    description=description,
    examples=examples,
)

if __name__ == "__main__":
    # Only start the server when run as a script, not when imported.
    iface.launch(share=True)