File size: 4,403 Bytes
b4563e1
 
 
 
 
 
 
 
7794a17
b4563e1
d521450
ea04b3d
b4563e1
 
d521450
b4563e1
 
 
 
 
 
 
 
 
 
59542a9
b4563e1
 
3a0ccc3
b4563e1
 
 
 
 
 
 
 
 
d521450
b4563e1
 
d521450
59542a9
 
d521450
8f6a2da
 
 
 
 
 
 
 
 
d521450
8e08196
b4563e1
59542a9
b4563e1
 
 
 
d521450
b4563e1
 
 
7794a17
8f6a2da
59542a9
 
 
 
 
 
 
 
a31e72e
 
 
 
 
59542a9
 
 
 
 
 
 
 
 
 
 
 
 
 
a31e72e
b4563e1
ae1742e
 
d521450
8e08196
59542a9
 
8e08196
ae1742e
8e08196
 
ae1742e
59542a9
 
 
8e08196
 
7794a17
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
import gradio as gr
import torch
from PIL import Image
from torchvision import transforms
import warnings
import sys
import os
import contextlib
from transformers import ViTForImageClassification, pipeline

# Suppress warnings related to the model weights initialization
warnings.filterwarnings("ignore", category=UserWarning, message=".*weights.*")
warnings.filterwarnings("ignore", category=FutureWarning, module="torch")

# Temporarily silence stdout (e.g. file-copy notices and verbose
# initialization messages printed while models load).
@contextlib.contextmanager
def suppress_stdout():
    """Redirect sys.stdout to os.devnull for the duration of the block.

    stdout is restored even if the wrapped code raises; stderr is untouched.
    """
    devnull = open(os.devnull, 'w')
    saved_stdout = sys.stdout
    sys.stdout = devnull
    try:
        yield
    finally:
        sys.stdout = saved_stdout
        devnull.close()

# Restore the fine-tuned ViT classifier (6 sugarcane disease classes)
# from the local checkpoint, hiding verbose init chatter while it loads.
with suppress_stdout():
    model = ViTForImageClassification.from_pretrained(
        'google/vit-base-patch16-224-in21k', num_labels=6
    )
    state_dict = torch.load(
        'vit_sugarcane_disease_detection.pth', map_location=torch.device('cpu')
    )
    model.load_state_dict(state_dict)
    model.eval()  # inference mode: disables dropout / batch-norm updates

# Preprocessing pipeline — must mirror the transforms used at training
# time: 224x224 resize, tensor conversion, ImageNet normalization stats.
_NORM_MEAN = [0.485, 0.456, 0.406]
_NORM_STD = [0.229, 0.224, 0.225]
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=_NORM_MEAN, std=_NORM_STD),
])

# Class-index -> label mapping; order must match the label encoding used
# when the checkpoint was trained (index i corresponds to logits[:, i]) —
# TODO confirm against the training script.
class_names = ['BacterialBlights', 'Healthy', 'Mosaic', 'RedRot', 'Rust', 'Yellow']

# Fallback text generator, used only when a predicted label is missing from
# the knowledge base below. NOTE(review): free-form GPT-2 generations are
# not agronomy-verified — confirm this fallback is acceptable for users.
ai_pipeline = pipeline("text-generation", model="gpt2", tokenizer="gpt2")

# Curated advice text per disease label, keyed by the entries of
# class_names. Every label in class_names has an entry here, so the
# AI-generated fallback in predict_disease is a safety net only.
knowledge_base = {
    'BacterialBlights': "Bacterial blights cause water-soaked lesions on leaves, leading to yellowing and withering. To manage, apply copper-based fungicides and ensure proper drainage.",
    'Mosaic': "Mosaic disease results in streaked and mottled leaves, reducing photosynthesis. Use disease-resistant varieties and control aphids to prevent spread.",
    'RedRot': "Red rot is identified by reddening and rotting of stalks. Remove infected plants and treat soil with appropriate fungicides.",
    'Rust': "Rust appears as orange pustules on leaves. Apply systemic fungicides and maintain optimal field conditions to reduce spread.",
    'Yellow': "Yellowing indicates nutrient deficiencies or initial disease stages. Test soil and provide balanced fertilizers.",
    'Healthy': "The sugarcane crop is healthy. Continue regular monitoring and good agronomic practices."
}

# Classify a sugarcane leaf image and build a styled HTML report.
def predict_disease(image):
    """Predict the sugarcane disease shown in ``image`` and return HTML advice.

    Args:
        image: PIL image supplied by the Gradio ``Image(type="pil")`` input.

    Returns:
        An HTML string containing the detected class name and management
        advice drawn from ``knowledge_base`` (or, for labels missing from
        it, generated by the local GPT-2 pipeline).
    """
    # Preprocess exactly as during training and add a batch dimension.
    img_tensor = transform(image).unsqueeze(0)

    # Forward pass without gradient tracking; take the arg-max class.
    with torch.no_grad():
        outputs = model(img_tensor)
        _, predicted_class = torch.max(outputs.logits, 1)

    predicted_label = class_names[predicted_class.item()]

    # Prefer the curated knowledge base; fall back to generated text only
    # for labels it does not cover (currently every class_names entry is
    # covered, so this is a safety net).
    if predicted_label in knowledge_base:
        detailed_response = knowledge_base[predicted_label]
    else:
        prompt = f"The detected sugarcane disease is '{predicted_label}'. Provide detailed advice for managing this condition."
        detailed_response = ai_pipeline(prompt, max_length=100, num_return_sequences=1, truncation=True)[0]['generated_text']

    # Styled HTML output. Fix: the original branched on
    # predicted_label != "Healthy" but emitted byte-identical markup in
    # both branches; the dead duplicate branch is removed (output unchanged).
    output_message = f"""
    <div style='font-size: 18px; color: #4CAF50; font-weight: bold;'>
        Detected Disease: <span style='color: #FF5722;'>{predicted_label}</span>
    </div>
    """
    output_message += f"""
        <p style='font-size: 16px; color: #757575;'>
            {detailed_response}
        </p>
        """

    return output_message

# --- Gradio UI wiring ---
# PIL image input; HTML output so the styled report renders as rich text.
inputs = gr.Image(type="pil")
outputs = gr.HTML()  # Use HTML output for styled text

# Example images shown under the input widget; these files must exist
# next to the script — TODO confirm they are shipped with the app.
EXAMPLES = ["img1.jpeg", "redrot2.jpg", "rust1.jpg", "healthy2.jpeg"]

demo_app = gr.Interface(
    fn=predict_disease,
    inputs=inputs,
    outputs=outputs,
    title="Sugarcane Disease Detection",
    examples=EXAMPLES,
    live=True,  # re-run prediction automatically whenever the input changes
    theme="huggingface"  # NOTE(review): string themes are deprecated in newer Gradio versions — verify installed version accepts this
)

# Blocks until the server stops; debug=True prints tracebacks to console.
demo_app.launch(debug=True)