sunbal7 commited on
Commit
963bae3
Β·
verified Β·
1 Parent(s): 5951091

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +47 -56
app.py CHANGED
@@ -2,6 +2,7 @@ import streamlit as st
2
  import requests
3
  from PIL import Image
4
  import torch
 
5
  from torchvision import transforms
6
  import os
7
  from groq import Groq
@@ -9,15 +10,45 @@ from groq import Groq
9
  # Initialize Groq client
10
  client = Groq(api_key=os.environ.get("GROQ_API_KEY"))
11
 
12
- # Set up the app title and layout
13
  st.set_page_config(page_title="Leaves Disease Detection", layout="wide")
14
  st.title("🌿 Leaves Disease Detection")
15
  st.write("Upload an image of a plant leaf to check for diseases and get treatment recommendations.")
16
 
17
- # Load the plant disease classification model
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
18
  @st.cache_resource
19
  def load_model():
20
- model = torch.hub.load('ultralytics/yolov5', 'custom', path='plant_disease_model.pt', force_reload=True)
 
 
21
  model.eval()
22
  return model
23
 
@@ -46,7 +77,7 @@ def classify_disease(image):
46
  _, predicted = torch.max(outputs, 1)
47
  class_idx = predicted.item()
48
 
49
- # Map class index to disease name (this should match your model's classes)
50
  disease_classes = [
51
  "Healthy", "Apple Scab", "Apple Black Rot", "Apple Cedar Rust",
52
  "Cherry Powdery Mildew", "Corn Gray Leaf Spot", "Corn Common Rust",
@@ -59,13 +90,13 @@ def classify_disease(image):
59
  "Tomato Target Spot", "Tomato Yellow Leaf Curl Virus", "Tomato Mosaic Virus"
60
  ]
61
 
62
- disease_name = disease_classes[class_idx]
63
  return disease_name
64
  except Exception as e:
65
  st.error(f"Error during classification: {str(e)}")
66
  return "Unknown"
67
 
68
- # Get disease description and treatment from Groq API
69
  def get_disease_info(disease_name):
70
  try:
71
  if disease_name.lower() == "healthy":
@@ -79,11 +110,11 @@ def get_disease_info(disease_name):
79
  messages=[
80
  {
81
  "role": "system",
82
- "content": "You are a plant pathologist assistant. Provide accurate information about plant diseases, their symptoms, treatments, and prevention methods."
83
  },
84
  {
85
  "role": "user",
86
- "content": f"Provide a detailed description of the plant disease {disease_name}, its symptoms, recommended treatments (including organic and chemical options), and prevention methods. Format the response with clear sections for Description, Treatment, and Prevention."
87
  }
88
  ],
89
  model="mixtral-8x7b-32768",
@@ -91,59 +122,28 @@ def get_disease_info(disease_name):
91
  max_tokens=1024
92
  )
93
 
94
- response = chat_completion.choices[0].message.content
95
- return parse_groq_response(response)
96
  except Exception as e:
97
  st.error(f"Error fetching disease information: {str(e)}")
98
  return {
99
- "description": "Information not available.",
100
- "treatment": "Please consult a local agricultural expert.",
101
- "prevention": "Regular monitoring and good cultural practices are recommended."
102
  }
103
 
104
- def parse_groq_response(response):
105
- # Simple parsing of the Groq response
106
- sections = {
107
- "description": "",
108
- "treatment": "",
109
- "prevention": ""
110
- }
111
-
112
- current_section = None
113
- for line in response.split('\n'):
114
- line_lower = line.lower()
115
- if "description" in line_lower:
116
- current_section = "description"
117
- sections[current_section] += line + "\n"
118
- elif "treatment" in line_lower:
119
- current_section = "treatment"
120
- sections[current_section] += line + "\n"
121
- elif "prevention" in line_lower:
122
- current_section = "prevention"
123
- sections[current_section] += line + "\n"
124
- elif current_section:
125
- sections[current_section] += line + "\n"
126
-
127
- return sections
128
-
129
- # Main application
130
  def main():
131
  uploaded_file = st.file_uploader("Upload a leaf image", type=["jpg", "jpeg", "png"])
132
 
133
  if uploaded_file is not None:
134
- # Display the uploaded image
135
  image = Image.open(uploaded_file)
136
  st.image(image, caption="Uploaded Leaf Image", use_column_width=True)
137
 
138
  if st.button("Predict Disease"):
139
  with st.spinner("Analyzing the leaf..."):
140
- # Classify the disease
141
  disease_name = classify_disease(image)
142
-
143
- # Get disease information
144
  disease_info = get_disease_info(disease_name)
145
 
146
- # Display results
147
  st.subheader("🌱 Analysis Results")
148
  col1, col2 = st.columns(2)
149
 
@@ -158,22 +158,13 @@ def main():
158
  st.success("βœ… Healthy Plant")
159
 
160
  st.subheader("πŸ“ Detailed Information")
161
-
162
- with st.expander("Description"):
163
- st.write(disease_info["description"])
164
-
165
- if disease_name.lower() != "healthy":
166
- with st.expander("Recommended Treatment"):
167
- st.write(disease_info["treatment"])
168
-
169
- with st.expander("Prevention Methods"):
170
- st.write(disease_info["prevention"])
171
 
172
  st.subheader("🌍 Real-world Summary")
173
  if disease_name.lower() == "healthy":
174
- st.write(f"The analysis indicates that the plant leaf appears healthy with no signs of disease. Healthy plants typically have vibrant color, uniform texture, and no visible spots or discoloration. Maintaining proper growing conditions, including adequate sunlight, water, and nutrients, will help keep the plant healthy.")
175
  else:
176
- st.write(f"The analysis detected {disease_name}, a common plant disease that can affect plant health and productivity. {disease_info['description'].split('.')[0]}. Early detection and proper treatment are crucial to prevent the spread of the disease and protect other plants in the vicinity. Following the recommended treatment and prevention methods can help restore plant health and prevent future outbreaks.")
177
 
178
  if __name__ == "__main__":
179
  main()
 
2
  import requests
3
  from PIL import Image
4
  import torch
5
+ import torch.nn as nn
6
  from torchvision import transforms
7
  import os
8
  from groq import Groq
 
10
  # Initialize Groq client
11
  client = Groq(api_key=os.environ.get("GROQ_API_KEY"))
12
 
13
+ # Set up the app
14
  st.set_page_config(page_title="Leaves Disease Detection", layout="wide")
15
  st.title("🌿 Leaves Disease Detection")
16
  st.write("Upload an image of a plant leaf to check for diseases and get treatment recommendations.")
17
 
18
+ # Simple CNN model for plant disease classification
19
+ class PlantDiseaseModel(nn.Module):
20
+ def __init__(self, num_classes=38):
21
+ super(PlantDiseaseModel, self).__init__()
22
+ self.features = nn.Sequential(
23
+ nn.Conv2d(3, 32, kernel_size=3, padding=1),
24
+ nn.ReLU(),
25
+ nn.MaxPool2d(kernel_size=2, stride=2),
26
+ nn.Conv2d(32, 64, kernel_size=3, padding=1),
27
+ nn.ReLU(),
28
+ nn.MaxPool2d(kernel_size=2, stride=2),
29
+ nn.Conv2d(64, 128, kernel_size=3, padding=1),
30
+ nn.ReLU(),
31
+ nn.MaxPool2d(kernel_size=2, stride=2),
32
+ )
33
+ self.classifier = nn.Sequential(
34
+ nn.Linear(128 * 32 * 32, 512),
35
+ nn.ReLU(),
36
+ nn.Dropout(0.5),
37
+ nn.Linear(512, num_classes)
38
+ )
39
+
40
+ def forward(self, x):
41
+ x = self.features(x)
42
+ x = x.view(x.size(0), -1)
43
+ x = self.classifier(x)
44
+ return x
45
+
46
+ # Load model (dummy implementation - in practice you'd load trained weights)
47
  @st.cache_resource
48
  def load_model():
49
+ model = PlantDiseaseModel()
50
+ # In a real app, you would load pre-trained weights here
51
+ # model.load_state_dict(torch.load('model_weights.pth'))
52
  model.eval()
53
  return model
54
 
 
77
  _, predicted = torch.max(outputs, 1)
78
  class_idx = predicted.item()
79
 
80
+ # Simplified disease classes (adjust based on your model)
81
  disease_classes = [
82
  "Healthy", "Apple Scab", "Apple Black Rot", "Apple Cedar Rust",
83
  "Cherry Powdery Mildew", "Corn Gray Leaf Spot", "Corn Common Rust",
 
90
  "Tomato Target Spot", "Tomato Yellow Leaf Curl Virus", "Tomato Mosaic Virus"
91
  ]
92
 
93
+ disease_name = disease_classes[class_idx % len(disease_classes)] # Ensure index is valid
94
  return disease_name
95
  except Exception as e:
96
  st.error(f"Error during classification: {str(e)}")
97
  return "Unknown"
98
 
99
+ # Get disease info from Groq API
100
  def get_disease_info(disease_name):
101
  try:
102
  if disease_name.lower() == "healthy":
 
110
  messages=[
111
  {
112
  "role": "system",
113
+ "content": "You are a plant pathologist assistant. Provide accurate information about plant diseases."
114
  },
115
  {
116
  "role": "user",
117
+ "content": f"Provide information about {disease_name} in plants with description, treatment, and prevention."
118
  }
119
  ],
120
  model="mixtral-8x7b-32768",
 
122
  max_tokens=1024
123
  )
124
 
125
+ return {"description": chat_completion.choices[0].message.content}
 
126
  except Exception as e:
127
  st.error(f"Error fetching disease information: {str(e)}")
128
  return {
129
+ "description": "Information not available. Please consult an expert.",
130
+ "treatment": "",
131
+ "prevention": ""
132
  }
133
 
134
+ # Main app
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
135
  def main():
136
  uploaded_file = st.file_uploader("Upload a leaf image", type=["jpg", "jpeg", "png"])
137
 
138
  if uploaded_file is not None:
 
139
  image = Image.open(uploaded_file)
140
  st.image(image, caption="Uploaded Leaf Image", use_column_width=True)
141
 
142
  if st.button("Predict Disease"):
143
  with st.spinner("Analyzing the leaf..."):
 
144
  disease_name = classify_disease(image)
 
 
145
  disease_info = get_disease_info(disease_name)
146
 
 
147
  st.subheader("🌱 Analysis Results")
148
  col1, col2 = st.columns(2)
149
 
 
158
  st.success("βœ… Healthy Plant")
159
 
160
  st.subheader("πŸ“ Detailed Information")
161
+ st.write(disease_info["description"])
 
 
 
 
 
 
 
 
 
162
 
163
  st.subheader("🌍 Real-world Summary")
164
  if disease_name.lower() == "healthy":
165
+ st.write("The analysis indicates a healthy plant leaf with no signs of disease.")
166
  else:
167
+ st.write(f"The analysis detected {disease_name}. Early detection and proper treatment are crucial.")
168
 
169
  if __name__ == "__main__":
170
  main()