File size: 3,570 Bytes
b4563e1
 
 
 
 
 
 
 
7794a17
b4563e1
ea04b3d
 
b4563e1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3a0ccc3
b4563e1
 
 
 
 
 
 
 
 
 
 
 
7794a17
 
 
b4563e1
8e08196
b4563e1
ae1742e
b4563e1
 
 
 
 
 
 
 
7794a17
755a244
b8c2514
755a244
b8c2514
755a244
 
 
 
a31e72e
 
 
 
 
 
 
 
 
 
755a244
a31e72e
 
 
 
 
 
 
 
 
 
b4563e1
ae1742e
 
a31e72e
8e08196
a31e72e
8e08196
 
ae1742e
8e08196
 
ae1742e
8e08196
 
 
 
 
7794a17
 
a31e72e
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
import contextlib
import html
import os
import sys
import warnings

import gradio as gr
import torch
from PIL import Image
from torchvision import transforms
from transformers import ViTForImageClassification, pipeline

# Silence noisy start-up warnings:
#  * UserWarnings whose message contains "weights" (case-insensitive), such
#    as notices about newly initialised model weights;
#  * FutureWarnings attributed to any module whose name starts with "torch"
#    (the module pattern is matched at the start, so torchvision is included).
warnings.filterwarnings(action="ignore", message=".*weights.*", category=UserWarning)
warnings.filterwarnings(action="ignore", module="torch", category=FutureWarning)

# Suppress output for copying files and verbose model initialization messages
@contextlib.contextmanager
def suppress_stdout():
    with open(os.devnull, 'w') as devnull:
        old_stdout = sys.stdout
        sys.stdout = devnull
        try:
            yield
        finally:
            sys.stdout = old_stdout

# Disease classes, indexed by the classifier's predicted class id. The order
# must match the label order used when the checkpoint was trained.
class_names = ['BacterialBlights', 'Healthy', 'Mosaic', 'RedRot', 'Rust', 'Yellow']

# Load the fine-tuned ViT checkpoint onto the CPU. suppress_stdout() hides only
# text printed to stdout while the model is built; it does not affect warnings.
with suppress_stdout():
    model = ViTForImageClassification.from_pretrained(
        'google/vit-base-patch16-224-in21k',
        num_labels=len(class_names),  # head size always follows the label list
    )
    # weights_only=True restricts unpickling to tensors and plain containers,
    # so a tampered .pth file cannot execute arbitrary code when loaded.
    model.load_state_dict(
        torch.load(
            'vit_sugarcane_disease_detection.pth',
            map_location=torch.device('cpu'),
            weights_only=True,
        )
    )
    model.eval()  # evaluation mode for inference (e.g. dropout disabled)

# Preprocessing; must be the same transformation used during training.
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# Local GPT-2 text-generation pipeline used to draft advice for detected diseases.
ai_pipeline = pipeline("text-generation", model="gpt2", tokenizer="gpt2")

def _finish_sentence(text):
    """Trim generated *text* back to its last complete sentence.

    Generation stops at a token budget, so output usually ends mid-sentence.
    The text is cut right after its last '.', '!' or '?'; if it contains none
    of these, a period is appended instead.
    """
    text = text.rstrip()
    last_stop = max(text.rfind(mark) for mark in ".!?")
    if last_stop == -1:
        return text + "."
    return text[:last_stop + 1]


def _classify(image):
    """Return the disease label the ViT model predicts for a PIL *image*."""
    # Uploads may be RGBA (PNG) or greyscale; the 3-channel Normalize in
    # `transform` only accepts RGB tensors.
    img_tensor = transform(image.convert("RGB")).unsqueeze(0)  # add batch dimension
    with torch.no_grad():
        logits = model(img_tensor).logits
    return class_names[logits.argmax(dim=1).item()]


def _generate_advice(label):
    """Generate a short management note for the disease *label*."""
    prompt = f"The detected sugarcane disease is '{label}'. Provide detailed advice for managing this condition."
    # return_full_text=False: keep only the continuation so the instruction
    # prompt itself is not shown to the user as "advice".
    generated = ai_pipeline(
        prompt,
        max_length=100,
        num_return_sequences=1,
        truncation=True,
        return_full_text=False,
    )[0]['generated_text']
    return _finish_sentence(generated)


def predict_disease(image):
    """Classify a sugarcane leaf image and return an HTML report.

    Args:
        image: PIL image from the ``gr.Image(type="pil")`` input, or ``None``
            when the image input is empty (for example after it is cleared).

    Returns:
        An HTML fragment naming the detected class. For any class other than
        "Healthy" it also contains generated advice; otherwise it contains a
        short reassurance. Returns an empty string when *image* is ``None``.
    """
    if image is None:
        return ""

    predicted_label = _classify(image)

    output_message = f"""
    <div style='font-size: 18px; color: #4CAF50; font-weight: bold;'>
        Detected Disease: <span style='color: #FF5722;'>{predicted_label}</span>
    </div>
    """

    if predicted_label != "Healthy":
        # Generate advice only when it will be displayed, and HTML-escape it:
        # model output is arbitrary text and must not be rendered as markup.
        advice = html.escape(_generate_advice(predicted_label))
        output_message += f"""
        <p style='font-size: 16px; color: #757575;'>
            {advice}
        </p>
        """
    else:
        output_message += """
        <p style='font-size: 16px; color: #757575;'>
            The sugarcane crop is <strong>healthy</strong>. Keep monitoring for potential risks.
        </p>
        """

    return output_message

# Build the Gradio interface: a PIL image in, a styled HTML report out.
inputs = gr.Image(type="pil")
outputs = gr.HTML()  # HTML output so the report's inline styling is rendered

# One-click example images; relative paths, presumably shipped next to this script.
EXAMPLES = ["img1.jpeg", "redrot2.jpg", "rust1.jpg", "healthy2.jpeg"]

demo_app = gr.Interface(
    fn=predict_disease,
    inputs=inputs,
    outputs=outputs,
    title="Sugarcane Disease Detection",
    examples=EXAMPLES,
    live=True,  # re-run prediction whenever the input image changes
    # NOTE(review): "huggingface" is a legacy theme name; newer Gradio
    # releases may not recognise it and fall back to the default — confirm.
    theme="huggingface",
)

if __name__ == "__main__":
    # launch() blocks until the server shuts down, so call it exactly once.
    demo_app.launch(debug=True)