"""Gradio demo that classifies a sugarcane leaf image into a disease type.

A ViT image classifier is built from the ``google/vit-base-patch16-224-in21k``
checkpoint with one output per entry in ``class_names``. Its fine-tuned weights
are then loaded from a local state-dict file, and it is served through a
``gr.Interface``.
"""

import contextlib
import os
import sys
import warnings

import gradio as gr
import torch
from PIL import Image
from torchvision import transforms
from transformers import ViTForImageClassification

# Suppress warnings related to the model weights initialization.
warnings.filterwarnings("ignore", category=UserWarning, message=".*weights.*")
warnings.filterwarnings("ignore", category=FutureWarning, module="torch")

# Disease labels, in the index order of the model's output logits.
# The model's output size is derived from this list so the two cannot drift apart.
class_names = ['BacterialBlights', 'Healthy', 'Mosaic', 'RedRot', 'Rust', 'Yellow']

MODEL_CHECKPOINT = "google/vit-base-patch16-224-in21k"
WEIGHTS_PATH = "vit_sugarcane_disease_detection.pth"


@contextlib.contextmanager
def suppress_stdout():
    """Temporarily discard everything written to ``sys.stdout``.

    Used to hide verbose messages printed while the model is initialised.
    The original stream is restored on exit, even if the body raises.
    """
    with open(os.devnull, "w") as devnull, contextlib.redirect_stdout(devnull):
        yield


def _load_model():
    """Build the ViT classifier, load the fine-tuned weights and switch to eval mode.

    Returns:
        A ``ViTForImageClassification`` on CPU, in inference (eval) mode.
    """
    with suppress_stdout():
        net = ViTForImageClassification.from_pretrained(
            MODEL_CHECKPOINT, num_labels=len(class_names)
        )
        # The file holds a plain state_dict (it is passed to load_state_dict), so
        # weights_only=True suffices and avoids unpickling arbitrary objects.
        state_dict = torch.load(WEIGHTS_PATH, map_location="cpu", weights_only=True)
        net.load_state_dict(state_dict)
    net.eval()
    return net


model = _load_model()

# Define the same transformation used during training.
# NOTE(review): these are ImageNet statistics; the in21k checkpoint's own image
# processor uses mean/std of 0.5 — keep whatever training actually used.
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])


def predict_disease(image) -> str:
    """Predict the disease type shown in a sugarcane leaf image.

    Args:
        image: A PIL image, or ``None`` when the input component is empty.
            With ``live=True`` the interface also calls this function when the
            user clears the image.

    Returns:
        The predicted entry of ``class_names``, or ``""`` when no image is given.
    """
    if image is None:
        return ""
    # Normalize expects exactly 3 channels; RGBA/greyscale/palette images
    # would otherwise produce 4- or 1-channel tensors and fail.
    if image.mode != "RGB":
        image = image.convert("RGB")

    img_tensor = transform(image).unsqueeze(0)  # add batch dimension
    with torch.no_grad():
        logits = model(img_tensor).logits
    predicted_index = int(logits.argmax(dim=1).item())
    return class_names[predicted_index]


# Create Gradio interface
inputs = gr.Image(type="pil")
outputs = gr.Text()
EXAMPLES = ["bacterialblight.jpeg"]

demo_app = gr.Interface(
    fn=predict_disease,
    inputs=inputs,
    outputs=outputs,
    title="Sugarcane Disease Detection",
    examples=EXAMPLES,
    live=True,
    # NOTE(review): "huggingface" is a legacy theme name; newer Gradio versions
    # may not recognise it and fall back to the default theme — confirm.
    theme="huggingface",
)

if __name__ == "__main__":
    demo_app.launch(debug=True)