import gradio as gr
import torch
from PIL import Image
from torchvision import transforms
import warnings
import sys
import os
import contextlib
from transformers import ViTForImageClassification

# Suppress warnings related to the model weights initialization
warnings.filterwarnings("ignore", category=UserWarning, message=".*weights.*")
warnings.filterwarnings("ignore", category=FutureWarning, module="torch")

# Context manager that temporarily redirects stdout to devnull, silencing verbose output during model loading
@contextlib.contextmanager
def suppress_stdout():
    with open(os.devnull, 'w') as devnull:
        old_stdout = sys.stdout
        sys.stdout = devnull
        try:
            yield
        finally:
            sys.stdout = old_stdout

# Load the pretrained ViT backbone, then the fine-tuned sugarcane-disease weights (console output suppressed)
with suppress_stdout():
    model = ViTForImageClassification.from_pretrained('google/vit-base-patch16-224-in21k', num_labels=6)
    model.load_state_dict(torch.load('vit_sugarcane_disease_detection.pth', map_location=torch.device('cpu')))
    model.eval()

# Define the same transformation used during training
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
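# The mean/std values above are the standard ImageNet normalization statistics;
# they are assumed to match the preprocessing used when the checkpoint was fine-tuned.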

# Define the class names (disease types)
class_names = ['BacterialBlights', 'Healthy', 'Mosaic', 'RedRot', 'Rust', 'Yellow']
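# The order of class_names must match the label indices used during training,
# since the model's predicted index is mapped directly to this list.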

# Function to predict disease type from an image
def predict_disease(image):
    # Apply transformations to the image
    img_tensor = transform(image).unsqueeze(0)  # Add batch dimension

    # Make prediction
    with torch.no_grad():
        outputs = model(img_tensor)
        _, predicted_class = torch.max(outputs.logits, 1)
    
    # Get the predicted label
    predicted_label = class_names[predicted_class.item()]
    
    # Create a styled HTML output
    output_message = f"""
    <div style='font-size: 18px; color: #4CAF50; font-weight: bold;'>
        Detected Disease: <span style='color: #FF5722;'>{predicted_label}</span>
    </div>
    """
    
    if predicted_label != "Healthy":
        output_message += f"""
        <p style='font-size: 16px; color: #757575;'>
            This indicates the presence of <strong>{predicted_label}</strong>. Please take immediate action to prevent further spread.
        </p>
        """
    else:
        output_message += """
        <p style='font-size: 16px; color: #757575;'>
            The sugarcane crop is <strong>healthy</strong>. Keep monitoring for potential risks.
        </p>
        """
    
    return output_message
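
# Quick sanity check (assumes a local test image such as "img1.jpeg" exists next to this script):
#   print(predict_disease(Image.open("img1.jpeg")))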

# Create Gradio interface
inputs = gr.Image(type="pil")
outputs = gr.HTML()  # Use HTML output for styled text

EXAMPLES = ["img1.jpeg", "redrot2.jpg", "rust1.jpg", "healthy2.jpeg"]
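# The example images above are assumed to live alongside this script in the repository.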

demo_app = gr.Interface(
    fn=predict_disease,
    inputs=inputs,
    outputs=outputs,
    title="Sugarcane Disease Detection",
    examples=EXAMPLES,
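    # live=True re-runs the prediction automatically whenever the input image changes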
    live=True,
    theme="huggingface"
)

demo_app.launch(debug=True)