Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
@@ -1,76 +1,179 @@
|
|
1 |
import streamlit as st
|
2 |
-
import
|
3 |
from PIL import Image
|
4 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
5 |
|
6 |
-
# Load
|
7 |
@st.cache_resource
|
8 |
def load_model():
|
9 |
-
|
10 |
-
model
|
11 |
-
return
|
12 |
|
13 |
-
|
14 |
|
15 |
-
#
|
16 |
-
|
17 |
-
|
18 |
-
|
19 |
-
|
20 |
-
|
21 |
-
|
22 |
-
|
23 |
-
# Add all 38 classes from PlantVillage dataset
|
24 |
-
# Complete mapping available at: https://github.com/nateraw/plant-village-classifier
|
25 |
-
}
|
26 |
|
27 |
-
#
|
28 |
-
|
29 |
-
|
30 |
-
|
31 |
-
|
32 |
-
|
33 |
-
|
34 |
-
|
35 |
-
|
36 |
-
|
37 |
-
|
38 |
-
|
39 |
-
|
40 |
-
|
41 |
-
|
42 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
43 |
]
|
44 |
-
|
45 |
-
|
46 |
-
|
47 |
-
|
48 |
-
st.
|
|
|
49 |
|
50 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
51 |
|
52 |
-
|
53 |
-
|
54 |
-
|
|
|
|
|
|
|
|
|
55 |
|
56 |
-
|
57 |
-
|
58 |
-
|
59 |
-
|
60 |
-
|
61 |
-
|
62 |
-
|
63 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
64 |
|
65 |
-
|
66 |
-
|
|
|
|
|
|
|
67 |
|
68 |
-
if
|
69 |
-
|
70 |
-
|
|
|
71 |
|
72 |
-
st.
|
73 |
-
|
74 |
-
|
75 |
-
|
76 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
import streamlit as st
|
2 |
+
import requests
|
3 |
from PIL import Image
|
4 |
+
import torch
|
5 |
+
from torchvision import transforms
|
6 |
+
import os
|
7 |
+
from groq import Groq
|
8 |
+
|
9 |
+
# Initialize Groq client
|
10 |
+
client = Groq(api_key=os.environ.get("GROQ_API_KEY"))
|
11 |
+
|
12 |
+
# Set up the app title and layout
|
13 |
+
st.set_page_config(page_title="Leaves Disease Detection", layout="wide")
|
14 |
+
st.title("πΏ Leaves Disease Detection")
|
15 |
+
st.write("Upload an image of a plant leaf to check for diseases and get treatment recommendations.")
|
16 |
|
17 |
+
# Load the plant disease classification model
|
18 |
@st.cache_resource
|
19 |
def load_model():
|
20 |
+
model = torch.hub.load('ultralytics/yolov5', 'custom', path='plant_disease_model.pt', force_reload=True)
|
21 |
+
model.eval()
|
22 |
+
return model
|
23 |
|
24 |
+
model = load_model()
|
25 |
|
26 |
+
# Image preprocessing
|
27 |
+
def preprocess_image(image):
|
28 |
+
transform = transforms.Compose([
|
29 |
+
transforms.Resize((256, 256)),
|
30 |
+
transforms.ToTensor(),
|
31 |
+
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
|
32 |
+
])
|
33 |
+
return transform(image).unsqueeze(0)
|
|
|
|
|
|
|
34 |
|
35 |
+
# Disease classification
|
36 |
+
def classify_disease(image):
|
37 |
+
try:
|
38 |
+
# Preprocess the image
|
39 |
+
img_tensor = preprocess_image(image)
|
40 |
+
|
41 |
+
# Perform inference
|
42 |
+
with torch.no_grad():
|
43 |
+
outputs = model(img_tensor)
|
44 |
+
|
45 |
+
# Get the predicted class
|
46 |
+
_, predicted = torch.max(outputs, 1)
|
47 |
+
class_idx = predicted.item()
|
48 |
+
|
49 |
+
# Map class index to disease name (this should match your model's classes)
|
50 |
+
disease_classes = [
|
51 |
+
"Healthy", "Apple Scab", "Apple Black Rot", "Apple Cedar Rust",
|
52 |
+
"Cherry Powdery Mildew", "Corn Gray Leaf Spot", "Corn Common Rust",
|
53 |
+
"Grape Black Rot", "Grape Esca", "Grape Leaf Blight",
|
54 |
+
"Orange Huanglongbing", "Peach Bacterial Spot", "Pepper Bacterial Spot",
|
55 |
+
"Potato Early Blight", "Potato Late Blight", "Raspberry Healthy",
|
56 |
+
"Soybean Healthy", "Squash Powdery Mildew", "Strawberry Leaf Scorch",
|
57 |
+
"Tomato Bacterial Spot", "Tomato Early Blight", "Tomato Late Blight",
|
58 |
+
"Tomato Leaf Mold", "Tomato Septoria Leaf Spot", "Tomato Spider Mites",
|
59 |
+
"Tomato Target Spot", "Tomato Yellow Leaf Curl Virus", "Tomato Mosaic Virus"
|
60 |
]
|
61 |
+
|
62 |
+
disease_name = disease_classes[class_idx]
|
63 |
+
return disease_name
|
64 |
+
except Exception as e:
|
65 |
+
st.error(f"Error during classification: {str(e)}")
|
66 |
+
return "Unknown"
|
67 |
|
68 |
+
# Get disease description and treatment from Groq API
|
69 |
+
def get_disease_info(disease_name):
|
70 |
+
try:
|
71 |
+
if disease_name.lower() == "healthy":
|
72 |
+
return {
|
73 |
+
"description": "The plant appears to be healthy with no visible signs of disease.",
|
74 |
+
"treatment": "No treatment needed. Continue with regular plant care practices.",
|
75 |
+
"prevention": "Maintain good growing conditions, proper watering, and regular monitoring."
|
76 |
+
}
|
77 |
+
|
78 |
+
chat_completion = client.chat.completions.create(
|
79 |
+
messages=[
|
80 |
+
{
|
81 |
+
"role": "system",
|
82 |
+
"content": "You are a plant pathologist assistant. Provide accurate information about plant diseases, their symptoms, treatments, and prevention methods."
|
83 |
+
},
|
84 |
+
{
|
85 |
+
"role": "user",
|
86 |
+
"content": f"Provide a detailed description of the plant disease {disease_name}, its symptoms, recommended treatments (including organic and chemical options), and prevention methods. Format the response with clear sections for Description, Treatment, and Prevention."
|
87 |
+
}
|
88 |
+
],
|
89 |
+
model="mixtral-8x7b-32768",
|
90 |
+
temperature=0.3,
|
91 |
+
max_tokens=1024
|
92 |
+
)
|
93 |
+
|
94 |
+
response = chat_completion.choices[0].message.content
|
95 |
+
return parse_groq_response(response)
|
96 |
+
except Exception as e:
|
97 |
+
st.error(f"Error fetching disease information: {str(e)}")
|
98 |
+
return {
|
99 |
+
"description": "Information not available.",
|
100 |
+
"treatment": "Please consult a local agricultural expert.",
|
101 |
+
"prevention": "Regular monitoring and good cultural practices are recommended."
|
102 |
+
}
|
103 |
|
104 |
+
def parse_groq_response(response):
|
105 |
+
# Simple parsing of the Groq response
|
106 |
+
sections = {
|
107 |
+
"description": "",
|
108 |
+
"treatment": "",
|
109 |
+
"prevention": ""
|
110 |
+
}
|
111 |
|
112 |
+
current_section = None
|
113 |
+
for line in response.split('\n'):
|
114 |
+
line_lower = line.lower()
|
115 |
+
if "description" in line_lower:
|
116 |
+
current_section = "description"
|
117 |
+
sections[current_section] += line + "\n"
|
118 |
+
elif "treatment" in line_lower:
|
119 |
+
current_section = "treatment"
|
120 |
+
sections[current_section] += line + "\n"
|
121 |
+
elif "prevention" in line_lower:
|
122 |
+
current_section = "prevention"
|
123 |
+
sections[current_section] += line + "\n"
|
124 |
+
elif current_section:
|
125 |
+
sections[current_section] += line + "\n"
|
126 |
|
127 |
+
return sections
|
128 |
+
|
129 |
+
# Main application
|
130 |
+
def main():
|
131 |
+
uploaded_file = st.file_uploader("Upload a leaf image", type=["jpg", "jpeg", "png"])
|
132 |
|
133 |
+
if uploaded_file is not None:
|
134 |
+
# Display the uploaded image
|
135 |
+
image = Image.open(uploaded_file)
|
136 |
+
st.image(image, caption="Uploaded Leaf Image", use_column_width=True)
|
137 |
|
138 |
+
if st.button("Predict Disease"):
|
139 |
+
with st.spinner("Analyzing the leaf..."):
|
140 |
+
# Classify the disease
|
141 |
+
disease_name = classify_disease(image)
|
142 |
+
|
143 |
+
# Get disease information
|
144 |
+
disease_info = get_disease_info(disease_name)
|
145 |
+
|
146 |
+
# Display results
|
147 |
+
st.subheader("π± Analysis Results")
|
148 |
+
col1, col2 = st.columns(2)
|
149 |
+
|
150 |
+
with col1:
|
151 |
+
st.markdown(f"**Status:** {'Healthy' if disease_name.lower() == 'healthy' else 'Diseased'}")
|
152 |
+
st.markdown(f"**Detected Condition:** {disease_name}")
|
153 |
+
|
154 |
+
with col2:
|
155 |
+
if disease_name.lower() != "healthy":
|
156 |
+
st.warning("β οΈ Disease Detected")
|
157 |
+
else:
|
158 |
+
st.success("β
Healthy Plant")
|
159 |
+
|
160 |
+
st.subheader("π Detailed Information")
|
161 |
+
|
162 |
+
with st.expander("Description"):
|
163 |
+
st.write(disease_info["description"])
|
164 |
+
|
165 |
+
if disease_name.lower() != "healthy":
|
166 |
+
with st.expander("Recommended Treatment"):
|
167 |
+
st.write(disease_info["treatment"])
|
168 |
+
|
169 |
+
with st.expander("Prevention Methods"):
|
170 |
+
st.write(disease_info["prevention"])
|
171 |
+
|
172 |
+
st.subheader("π Real-world Summary")
|
173 |
+
if disease_name.lower() == "healthy":
|
174 |
+
st.write(f"The analysis indicates that the plant leaf appears healthy with no signs of disease. Healthy plants typically have vibrant color, uniform texture, and no visible spots or discoloration. Maintaining proper growing conditions, including adequate sunlight, water, and nutrients, will help keep the plant healthy.")
|
175 |
+
else:
|
176 |
+
st.write(f"The analysis detected {disease_name}, a common plant disease that can affect plant health and productivity. {disease_info['description'].split('.')[0]}. Early detection and proper treatment are crucial to prevent the spread of the disease and protect other plants in the vicinity. Following the recommended treatment and prevention methods can help restore plant health and prevent future outbreaks.")
|
177 |
+
|
178 |
+
if __name__ == "__main__":
|
179 |
+
main()
|