# Hugging Face Spaces status banner captured with this file: "Runtime error".
import gradio as gr | |
import torch | |
from PIL import Image | |
from torchvision import transforms | |
import warnings | |
import sys | |
import os | |
import contextlib | |
from transformers import ViTForImageClassification | |
# Hide two categories of warnings at startup: UserWarnings whose text
# mentions "weights", and FutureWarnings raised from the torch package.
warnings.filterwarnings(
    action="ignore",
    message=".*weights.*",
    category=UserWarning,
)
warnings.filterwarnings(
    action="ignore",
    category=FutureWarning,
    module="torch",
)
# Suppress output for copying files and verbose model initialization messages | |
@contextlib.contextmanager
def suppress_stdout():
    """Temporarily send everything written to ``sys.stdout`` to the null device.

    Use as ``with suppress_stdout(): ...``. The ``contextmanager`` decorator
    is what makes this generator usable in a ``with`` statement; without it
    the ``with`` raises because a plain generator has no ``__enter__`` /
    ``__exit__``.

    Yields:
        None. The previous ``sys.stdout`` is restored when the block exits,
        including when the block raises.
    """
    with open(os.devnull, 'w') as devnull:
        old_stdout = sys.stdout
        sys.stdout = devnull
        try:
            yield
        finally:
            # Always put the real stream back, even if the body failed.
            sys.stdout = old_stdout
# Build the ViT classifier with a 6-way head (stdout is muted while the
# pretrained backbone is fetched/initialised), then overwrite its weights
# with the fine-tuned checkpoint, mapped onto the CPU.
with suppress_stdout():
    model = ViTForImageClassification.from_pretrained(
        'google/vit-base-patch16-224-in21k',
        num_labels=6,
    )
model.load_state_dict(
    torch.load(
        'vit_sugarcane_disease_detection.pth',
        map_location=torch.device('cpu'),
    )
)
# Switch to inference mode.
model.eval()
# Preprocessing applied to every input image before inference: resize to
# 224x224, convert to a tensor, then normalise each of the three channels.
# NOTE(review): the original comment states this matches the training-time
# transform; the mean/std look like the usual ImageNet statistics — confirm.
transform = transforms.Compose(
    [
        transforms.Resize(size=(224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)
# Human-readable label for each model output index (position == class id).
class_names = [
    'BacterialBlights',
    'Healthy',
    'Mosaic',
    'RedRot',
    'Rust',
    'Yellow',
]
# Function to predict disease type from an image | |
def predict_disease(image):
    """Classify a sugarcane leaf image and return the result as styled HTML.

    Args:
        image: A PIL image (the Gradio input is declared with ``type="pil"``),
            or ``None`` when no image is present, e.g. after the user clears
            the input while the interface runs in live mode.

    Returns:
        An HTML string. For a real image it names the predicted class from
        ``class_names`` and adds a short recommendation; for ``None`` it is a
        prompt asking the user to upload an image.
    """
    # Nothing to classify: answer with a hint instead of crashing in the
    # transform pipeline.
    if image is None:
        return (
            "<p style='font-size: 16px; color: #757575;'>"
            "Please upload an image of a sugarcane leaf to get a prediction."
            "</p>"
        )

    # The normalisation step is defined for exactly three channels, so force
    # RGB first (uploads may be RGBA, grayscale or palette images).
    img_tensor = transform(image.convert("RGB")).unsqueeze(0)  # add batch dim

    # Inference only; no gradients needed.
    with torch.no_grad():
        outputs = model(img_tensor)
    predicted_index = outputs.logits.argmax(dim=1).item()
    predicted_label = class_names[predicted_index]

    # Headline with the predicted label.
    output_message = (
        "\n"
        "<div style='font-size: 18px; color: #4CAF50; font-weight: bold;'>\n"
        f"Detected Disease: <span style='color: #FF5722;'>{predicted_label}</span>\n"
        "</div>\n"
    )

    # Follow-up advice depends on whether a disease was detected.
    if predicted_label != "Healthy":
        output_message += (
            "\n"
            "<p style='font-size: 16px; color: #757575;'>\n"
            f"This indicates the presence of <strong>{predicted_label}</strong>. Please take immediate action to prevent further spread.\n"
            "</p>\n"
        )
    else:
        output_message += (
            "\n"
            "<p style='font-size: 16px; color: #757575;'>\n"
            "The sugarcane crop is <strong>healthy</strong>. Keep monitoring for potential risks.\n"
            "</p>\n"
        )
    return output_message
# Sample images offered below the input widget.
EXAMPLES = ["img1.jpeg", "redrot2.jpg", "rust1.jpg", "healthy2.jpeg"]

# Input arrives as a PIL image; the result is rendered as HTML so the
# styled message produced by predict_disease is displayed as markup.
inputs = gr.Image(type="pil")
outputs = gr.HTML()

# Assemble the UI around predict_disease; live=True re-runs the prediction
# whenever the input changes.
demo_app = gr.Interface(
    fn=predict_disease,
    inputs=inputs,
    outputs=outputs,
    title="Sugarcane Disease Detection",
    examples=EXAMPLES,
    live=True,
    theme="huggingface",
)

# Start the web server with debug output enabled.
demo_app.launch(debug=True)