Spaces:
Runtime error
Runtime error
File size: 3,570 Bytes
b4563e1 7794a17 b4563e1 ea04b3d b4563e1 3a0ccc3 b4563e1 7794a17 b4563e1 8e08196 b4563e1 ae1742e b4563e1 7794a17 755a244 b8c2514 755a244 b8c2514 755a244 a31e72e 755a244 a31e72e b4563e1 ae1742e a31e72e 8e08196 a31e72e 8e08196 ae1742e 8e08196 ae1742e 8e08196 7794a17 a31e72e |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 |
import contextlib
import html
import os
import sys
import warnings

import gradio as gr
import torch
from PIL import Image
from torchvision import transforms
from transformers import ViTForImageClassification, pipeline
# Silence noisy warnings emitted while the models load:
#   * UserWarnings whose message matches ".*weights.*"
#   * FutureWarnings attributed to modules whose name starts with "torch"
# Filters are added in this order; each filterwarnings() call inserts at the
# front of the filter list.
for _category, _filter_kwargs in (
    (UserWarning, {"message": ".*weights.*"}),
    (FutureWarning, {"module": "torch"}),
):
    warnings.filterwarnings("ignore", category=_category, **_filter_kwargs)
del _category, _filter_kwargs
# Suppress output for copying files and verbose model initialization messages
@contextlib.contextmanager
def suppress_stdout():
with open(os.devnull, 'w') as devnull:
old_stdout = sys.stdout
sys.stdout = devnull
try:
yield
finally:
sys.stdout = old_stdout
# Load the saved model and suppress the warnings
with suppress_stdout():
model = ViTForImageClassification.from_pretrained('google/vit-base-patch16-224-in21k', num_labels=6)
model.load_state_dict(torch.load('vit_sugarcane_disease_detection.pth', map_location=torch.device('cpu')))
model.eval()
# Define the same transformation used during training
transform = transforms.Compose([
transforms.Resize((224, 224)),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
# Load the class names (disease types)
class_names = ['BacterialBlights', 'Healthy', 'Mosaic', 'RedRot', 'Rust', 'Yellow']
# Load AI response generator (using a local GPT pipeline or OpenAI's GPT-3/4 API)
ai_pipeline = pipeline("text-generation", model="gpt2", tokenizer="gpt2")
# Function to predict disease type from an image
def predict_disease(image):
# Apply transformations to the image
img_tensor = transform(image).unsqueeze(0) # Add batch dimension
# Make prediction
with torch.no_grad():
outputs = model(img_tensor)
_, predicted_class = torch.max(outputs.logits, 1)
# Get the predicted label
predicted_label = class_names[predicted_class.item()]
# Generate a detailed response for the detected disease
prompt = f"The detected sugarcane disease is '{predicted_label}'. Provide detailed advice for managing this condition."
ai_response = ai_pipeline(prompt, max_length=100, num_return_sequences=1, truncation=True)[0]['generated_text']
# Post-process the AI response to ensure it ends with a complete sentence
if not ai_response.endswith(('.', '!', '?')):
ai_response = ai_response.rsplit('.', 1)[0] + '.'
# Create a styled HTML output
output_message = f"""
<div style='font-size: 18px; color: #4CAF50; font-weight: bold;'>
Detected Disease: <span style='color: #FF5722;'>{predicted_label}</span>
</div>
"""
if predicted_label != "Healthy":
output_message += f"""
<p style='font-size: 16px; color: #757575;'>
{ai_response}
</p>
"""
else:
output_message += f"""
<p style='font-size: 16px; color: #757575;'>
The sugarcane crop is <strong>healthy</strong>. Keep monitoring for potential risks.
</p>
"""
return output_message
# Create Gradio interface
inputs = gr.Image(type="pil")
outputs = gr.HTML() # Use HTML output for styled text
EXAMPLES = ["img1.jpeg", "redrot2.jpg", "rust1.jpg", "healthy2.jpeg"]
demo_app = gr.Interface(
fn=predict_disease,
inputs=inputs,
outputs=outputs,
title="Sugarcane Disease Detection",
examples=EXAMPLES,
live=True,
theme="huggingface"
)
demo_app.launch(debug=True)
demo_app.launch(debug=True) |