"""Gradio demo that classifies the freshness/ripeness of a fruit or vegetable photo.

Loads a fine-tuned ViT image-classification checkpoint from the Hugging Face
Hub and serves a single-image upload UI that shows the predicted class name.
"""

from typing import Optional

import gradio as gr
import torch
from PIL import Image
from transformers import ViTForImageClassification, ViTImageProcessor

MODEL_ID = "shahmi0519/fypvit"

# Class labels, indexed by the classifier's output logit position
# (modify according to your model).
# NOTE(review): these are hard-coded; confirm they match the order of
# model.config.id2label for the checkpoint above.
class_labels = [
    "Bellpepper_fresh", "Bellpepper_intermediate_fresh", "Bellpepper_rotten",
    "Carrot_fresh", "Carrot_intermediate_fresh", "Carrot_rotten",
    "Cucumber_fresh", "Cucumber_intermediate_fresh", "Cucumber_rotten",
    "Potato_fresh", "Potato_intermediate_fresh", "Potato_rotten",
    "Tomato_fresh", "Tomato_intermediate_fresh", "Tomato_rotten",
    "ripe_apple", "ripe_banana", "ripe_mango", "ripe_oranges", "ripe_strawberry",
    "rotten_apple", "rotten_banana", "rotten_mango", "rotten_oranges", "rotten_strawberry",
    "unripe_apple", "unripe_banana", "unripe_mango", "unripe_oranges", "unripe_strawberry",
]

# Load model and image processor.
# NOTE(review): with ignore_mismatched_sizes=True, a checkpoint whose
# classification head does not have len(class_labels) outputs gets a freshly
# (randomly) initialised head instead of an error, and predictions would be
# meaningless -- confirm the checkpoint was trained with these 30 labels.
model = ViTForImageClassification.from_pretrained(
    MODEL_ID,
    num_labels=len(class_labels),
    ignore_mismatched_sizes=True,
)
# ViTImageProcessor replaces the deprecated ViTFeatureExtractor; the variable
# name is kept for compatibility with any code that refers to it.
feature_extractor = ViTImageProcessor.from_pretrained(MODEL_ID)

# Move to GPU if available and switch to inference mode once, at startup.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)
model.eval()


def predict_freshness(image: Optional[Image.Image]) -> Optional[str]:
    """Return the predicted freshness label for an uploaded image.

    Args:
        image: PIL image from the Gradio upload widget, or None when no
            image has been provided.

    Returns:
        The entry of ``class_labels`` for the highest-scoring logit,
        ``"Class <idx>"`` if that index falls outside ``class_labels``,
        or None when ``image`` is None (clears the output label).
    """
    if image is None:
        # Gradio passes None when the image input is empty or cleared.
        return None

    # The processor normalises with per-channel RGB statistics, so images in
    # other modes (e.g. RGBA PNGs with an alpha channel) are converted first.
    if image.mode != "RGB":
        image = image.convert("RGB")

    inputs = feature_extractor(images=image, return_tensors="pt").to(device)

    with torch.no_grad():
        logits = model(**inputs).logits
    predicted_class_idx = logits.argmax(-1).item()

    # argmax is never negative, so only the upper bound needs checking.
    if predicted_class_idx < len(class_labels):
        return class_labels[predicted_class_idx]
    return f"Class {predicted_class_idx}"


# Create Gradio interface
title = "Freshness Detector"
description = "Upload an image of fruit/vegetable to detect its freshness state"
# Example image files are expected next to this script.
examples = [
    ["apple.jpeg"],
    ["banana.jpeg"],
    ["tomato.jpeg"],
]

iface = gr.Interface(
    fn=predict_freshness,
    inputs=gr.Image(type="pil", label="Upload Image"),
    outputs=gr.Label(label="Freshness State"),
    title=title,
    description=description,
    examples=examples,
)

if __name__ == "__main__":
    # Only start the server when run as a script, not when imported.
    iface.launch(share=True)