Spaces:
Runtime error
Runtime error
import contextlib
import html
import os
import sys
import warnings

import gradio as gr
import torch
from PIL import Image
from torchvision import transforms
from transformers import ViTForImageClassification, pipeline
# Suppress warnings related to the model weights initialization | |
warnings.filterwarnings("ignore", category=UserWarning, message=".*weights.*") | |
warnings.filterwarnings("ignore", category=FutureWarning, module="torch") | |
# Suppress output for copying files and verbose model initialization messages | |
def suppress_stdout(): | |
with open(os.devnull, 'w') as devnull: | |
old_stdout = sys.stdout | |
sys.stdout = devnull | |
try: | |
yield | |
finally: | |
sys.stdout = old_stdout | |
# Load the saved model and suppress the warnings | |
with suppress_stdout(): | |
model = ViTForImageClassification.from_pretrained('google/vit-base-patch16-224-in21k', num_labels=6) | |
model.load_state_dict(torch.load('vit_sugarcane_disease_detection.pth', map_location=torch.device('cpu'))) | |
model.eval() | |
# Define the same transformation used during training | |
transform = transforms.Compose([ | |
transforms.Resize((224, 224)), | |
transforms.ToTensor(), | |
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), | |
]) | |
# Load the class names (disease types) | |
class_names = ['BacterialBlights', 'Healthy', 'Mosaic', 'RedRot', 'Rust', 'Yellow'] | |
# Load AI response generator (using a local GPT pipeline or OpenAI's GPT-3/4 API) | |
ai_pipeline = pipeline("text-generation", model="gpt2", tokenizer="gpt2") | |
# Knowledge base for sugarcane diseases (example data from the website) | |
knowledge_base = { | |
'BacterialBlights': "Bacterial blights cause water-soaked lesions on leaves, leading to yellowing and withering. To manage, apply copper-based fungicides and ensure proper drainage.", | |
'Mosaic': "Mosaic disease results in streaked and mottled leaves, reducing photosynthesis. Use disease-resistant varieties and control aphids to prevent spread.", | |
'RedRot': "Red rot is identified by reddening and rotting of stalks. Remove infected plants and treat soil with appropriate fungicides.", | |
'Rust': "Rust appears as orange pustules on leaves. Apply systemic fungicides and maintain optimal field conditions to reduce spread.", | |
'Yellow': "Yellowing indicates nutrient deficiencies or initial disease stages. Test soil and provide balanced fertilizers.", | |
'Healthy': "The sugarcane crop is healthy. Continue regular monitoring and good agronomic practices." | |
} | |
# Update the predict_disease function | |
def predict_disease(image): | |
# Apply transformations to the image | |
img_tensor = transform(image).unsqueeze(0) # Add batch dimension | |
# Make prediction | |
with torch.no_grad(): | |
outputs = model(img_tensor) | |
_, predicted_class = torch.max(outputs.logits, 1) | |
# Get the predicted label | |
predicted_label = class_names[predicted_class.item()] | |
# Retrieve response from knowledge base | |
if predicted_label in knowledge_base: | |
detailed_response = knowledge_base[predicted_label] | |
else: | |
# Fallback to AI-generated response | |
prompt = f"The detected sugarcane disease is '{predicted_label}'. Provide detailed advice for managing this condition." | |
detailed_response = ai_pipeline(prompt, max_length=100, num_return_sequences=1, truncation=True)[0]['generated_text'] | |
# Create a styled HTML output | |
output_message = f""" | |
<div style='font-size: 18px; color: #4CAF50; font-weight: bold;'> | |
Detected Disease: <span style='color: #FF5722;'>{predicted_label}</span> | |
</div> | |
""" | |
if predicted_label != "Healthy": | |
output_message += f""" | |
<p style='font-size: 16px; color: #757575;'> | |
{detailed_response} | |
</p> | |
""" | |
else: | |
output_message += f""" | |
<p style='font-size: 16px; color: #757575;'> | |
{detailed_response} | |
</p> | |
""" | |
return output_message | |
# Create Gradio interface | |
inputs = gr.Image(type="pil") | |
outputs = gr.HTML() # Use HTML output for styled text | |
EXAMPLES = ["img1.jpeg", "redrot2.jpg", "rust1.jpg", "healthy2.jpeg"] | |
demo_app = gr.Interface( | |
fn=predict_disease, | |
inputs=inputs, | |
outputs=outputs, | |
title="Sugarcane Disease Detection", | |
examples=EXAMPLES, | |
live=True, | |
theme="huggingface" | |
) | |
demo_app.launch(debug=True) | |