saritha's picture
Update app.py
ae1742e verified
raw
history blame
2.29 kB
import gradio as gr
import torch
from PIL import Image
from torchvision import transforms
import warnings
import sys
import os
import contextlib
from transformers import ViTForImageClassification
# Silence noisy-but-harmless warnings raised while the model is loaded:
# UserWarnings coming from `transformers` modules and FutureWarnings coming
# from `torch` modules. `module` is a regex matched against the start of the
# fully qualified name of the module that issues the warning.
# NOTE(review): transformers' "some weights were newly initialized" notice is
# presumably emitted through its logging module rather than `warnings`, so
# these filters may not hide it — confirm.
_SILENCED_WARNINGS = (
    (UserWarning, "transformers"),
    (FutureWarning, "torch"),
)
for _category, _module in _SILENCED_WARNINGS:
    warnings.filterwarnings("ignore", category=_category, module=_module)
# Context manager used to hide chatty print() output (e.g. during model
# construction). It only redirects stdout; anything written to stderr (where
# `logging` and `warnings` write by default) still appears.
@contextlib.contextmanager
def suppress_stdout():
    """Temporarily redirect ``sys.stdout`` to the null device.

    The previous ``sys.stdout`` is restored on exit, even if the body raises,
    and the null-device handle is closed afterwards.
    """
    with open(os.devnull, "w", encoding="utf-8") as devnull:
        # redirect_stdout performs the save/swap/restore that was hand-rolled.
        with contextlib.redirect_stdout(devnull):
            yield
# Build the ViT architecture with a 6-way classification head (6 must equal
# len(class_names)), then overwrite its weights with the fine-tuned
# checkpoint. Only stdout is silenced here; warnings are handled by the
# `warnings` filters, not by suppress_stdout().
with suppress_stdout():
    model = ViTForImageClassification.from_pretrained(
        'google/vit-base-patch16-224-in21k', num_labels=6
    )
    # The checkpoint is a plain state_dict, so restrict unpickling to tensors
    # and primitive containers (weights_only=True) instead of allowing
    # arbitrary pickled objects. The result is passed straight through so no
    # extra module-level reference keeps a second copy of the weights alive.
    model.load_state_dict(
        torch.load(
            'vit_sugarcane_disease_detection.pth',
            map_location=torch.device('cpu'),
            weights_only=True,
        )
    )
# Switch to inference mode (changes the behavior of layers such as dropout).
model.eval()
# Preprocessing pipeline. It is meant to mirror the one used when the
# checkpoint was trained, so keep these values in sync with training.
IMAGENET_MEAN = [0.485, 0.456, 0.406]  # per-channel (R, G, B) mean
IMAGENET_STD = [0.229, 0.224, 0.225]   # per-channel (R, G, B) std
_preprocessing_steps = [
    transforms.Resize((224, 224)),  # fixed size; 224 matches the checkpoint name
    transforms.ToTensor(),          # PIL image -> float tensor (C, H, W)
    transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
]
transform = transforms.Compose(_preprocessing_steps)
# Disease labels in classifier-output order: predict_disease() returns
# class_names[i] for the highest-scoring logit i. The order presumably matches
# the label encoding used during training — confirm before editing.
class_names = [
    'BacterialBlights',
    'Healthy',
    'Mosaic',
    'RedRot',
    'Rust',
    'Yellow',
]
# Gradio callback: predict the disease type shown in a single image.
def predict_disease(image):
    """Classify an image and return the predicted disease label.

    Args:
        image: A ``PIL.Image.Image`` from the Gradio ``Image(type="pil")``
            input, or ``None`` when the input is cleared (the interface runs
            with ``live=True``, so it is re-invoked on every change).

    Returns:
        One of the strings in ``class_names``, or ``""`` when no image is given.
    """
    if image is None:
        return ""
    # `transform` normalizes exactly 3 channels, and the model expects RGB;
    # RGBA (e.g. PNG with alpha) or greyscale inputs would otherwise raise.
    if image.mode != "RGB":
        image = image.convert("RGB")
    batch = transform(image).unsqueeze(0)  # add a leading batch dimension
    with torch.no_grad():
        logits = model(batch).logits
    predicted_index = logits.argmax(dim=1).item()
    return class_names[predicted_index]
# Build the Gradio UI: one PIL image in, the predicted label (text) out.
inputs = gr.Image(type="pil")
outputs = gr.Text()
# Example image paths, resolved relative to the working directory.
EXAMPLES = ["img1.png", "img2.png", "img3.png", "img4.png"]
demo_app = gr.Interface(
    fn=predict_disease,
    inputs=inputs,
    outputs=outputs,
    title="Sugarcane Disease Detection",
    examples=EXAMPLES,
    # The keyword is `cache_examples` (plural); `cache_example` is not a
    # recognised Interface parameter.
    cache_examples=True,
    live=True,  # re-run the prediction whenever the input image changes
    # NOTE(review): "huggingface" was a built-in theme name in older Gradio;
    # newer versions may not recognise it and fall back to the default — confirm.
    theme="huggingface",
)
# `launch(enable_queue=...)` is deprecated and was removed in Gradio 4;
# enabling the queue via `.queue()` works across versions.
demo_app.queue().launch(debug=True)