sunbal7 commited on
Commit
0d611dd
Β·
verified Β·
1 Parent(s): 709baf9

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +54 -73
app.py CHANGED
@@ -4,39 +4,36 @@ import torch
4
  import torch.nn as nn
5
  from torchvision import transforms
6
  import os
 
7
  from groq import Groq
8
 
9
- # Set up the app
 
 
 
10
  st.set_page_config(page_title="Leaves Disease Detection", layout="wide")
11
  st.title("🌿 Leaves Disease Detection")
12
  st.write("Upload an image of a plant leaf to check for diseases and get treatment recommendations.")
13
 
14
  # Initialize Groq client
15
  try:
16
- client = Groq(api_key=os.environ.get("GROQ_API_KEY"))
 
17
  except Exception as e:
18
  st.error(f"Failed to initialize Groq client: {str(e)}")
19
  client = None
20
 
21
- # Simple CNN model for plant disease classification
22
  class PlantDiseaseModel(nn.Module):
23
- def __init__(self, num_classes=38):
24
  super(PlantDiseaseModel, self).__init__()
25
  self.features = nn.Sequential(
26
- nn.Conv2d(3, 32, kernel_size=3, padding=1),
27
- nn.ReLU(),
28
- nn.MaxPool2d(kernel_size=2, stride=2),
29
- nn.Conv2d(32, 64, kernel_size=3, padding=1),
30
- nn.ReLU(),
31
- nn.MaxPool2d(kernel_size=2, stride=2),
32
- nn.Conv2d(64, 128, kernel_size=3, padding=1),
33
- nn.ReLU(),
34
- nn.MaxPool2d(kernel_size=2, stride=2),
35
  )
36
  self.classifier = nn.Sequential(
37
- nn.Linear(128 * 32 * 32, 512),
38
- nn.ReLU(),
39
- nn.Dropout(0.5),
40
  nn.Linear(512, num_classes)
41
  )
42
 
@@ -46,7 +43,7 @@ class PlantDiseaseModel(nn.Module):
46
  x = self.classifier(x)
47
  return x
48
 
49
- # Load model (dummy implementation)
50
  @st.cache_resource
51
  def load_model():
52
  model = PlantDiseaseModel()
@@ -55,16 +52,30 @@ def load_model():
55
 
56
  model = load_model()
57
 
58
- # Image preprocessing
59
  def preprocess_image(image):
60
  transform = transforms.Compose([
61
  transforms.Resize((256, 256)),
62
  transforms.ToTensor(),
63
- transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
 
64
  ])
65
  return transform(image).unsqueeze(0)
66
 
67
- # Disease classification
 
 
 
 
 
 
 
 
 
 
 
 
 
68
  def classify_disease(image):
69
  try:
70
  img_tensor = preprocess_image(image)
@@ -72,99 +83,69 @@ def classify_disease(image):
72
  outputs = model(img_tensor)
73
  _, predicted = torch.max(outputs, 1)
74
  class_idx = predicted.item()
75
-
76
- disease_classes = [
77
- "Healthy", "Apple Scab", "Apple Black Rot", "Apple Cedar Rust",
78
- "Cherry Powdery Mildew", "Corn Gray Leaf Spot", "Corn Common Rust",
79
- "Grape Black Rot", "Grape Esca", "Grape Leaf Blight",
80
- "Orange Huanglongbing", "Peach Bacterial Spot", "Pepper Bacterial Spot",
81
- "Potato Early Blight", "Potato Late Blight", "Raspberry Healthy",
82
- "Soybean Healthy", "Squash Powdery Mildew", "Strawberry Leaf Scorch",
83
- "Tomato Bacterial Spot", "Tomato Early Blight", "Tomato Late Blight",
84
- "Tomato Leaf Mold", "Tomato Septoria Leaf Spot", "Tomato Spider Mites",
85
- "Tomato Target Spot", "Tomato Yellow Leaf Curl Virus", "Tomato Mosaic Virus"
86
- ]
87
-
88
- disease_name = disease_classes[class_idx % len(disease_classes)]
89
- return disease_name
90
  except Exception as e:
91
  st.error(f"Error during classification: {str(e)}")
92
  return "Unknown"
93
 
94
- # Get disease info from Groq API
95
  def get_disease_info(disease_name):
96
  if not client:
97
  return {
98
- "description": "API connection not available. Please check your Groq API key.",
99
- "treatment": "",
100
- "prevention": ""
101
  }
102
-
103
  try:
104
  if disease_name.lower() == "healthy":
105
  return {
106
- "description": "The plant appears to be healthy with no visible signs of disease.",
107
- "treatment": "No treatment needed. Continue with regular plant care practices.",
108
- "prevention": "Maintain good growing conditions, proper watering, and regular monitoring."
109
  }
110
-
111
  response = client.chat.completions.create(
112
  messages=[
113
  {"role": "system", "content": "You are a plant pathologist assistant."},
114
- {"role": "user", "content": f"Describe {disease_name} in plants with symptoms, treatment, and prevention."}
115
  ],
116
  model="mixtral-8x7b-32768",
117
  temperature=0.3,
118
  max_tokens=1024
119
  )
120
-
121
  return {"description": response.choices[0].message.content}
122
  except Exception as e:
123
  st.error(f"Error fetching disease information: {str(e)}")
124
  return {
125
- "description": "Information not available. Please consult an expert.",
126
- "treatment": "",
127
- "prevention": ""
128
  }
129
 
130
- # Main app
131
  def main():
132
  uploaded_file = st.file_uploader("Upload a leaf image", type=["jpg", "jpeg", "png"])
133
 
134
- if uploaded_file is not None:
135
  try:
136
- image = Image.open(uploaded_file)
137
  st.image(image, caption="Uploaded Leaf Image", use_column_width=True)
138
-
139
- if st.button("Predict Disease"):
140
  with st.spinner("Analyzing the leaf..."):
141
  disease_name = classify_disease(image)
142
- disease_info = get_disease_info(disease_name)
143
-
144
- st.subheader("🌱 Analysis Results")
145
  col1, col2 = st.columns(2)
146
-
147
  with col1:
148
  status = "Healthy" if disease_name.lower() == "healthy" else "Diseased"
149
  st.markdown(f"**Status:** {status}")
150
- st.markdown(f"**Detected Condition:** {disease_name}")
151
-
152
  with col2:
153
- if disease_name.lower() != "healthy":
154
- st.warning("⚠️ Disease Detected")
155
  else:
156
- st.success("βœ… Healthy Plant")
157
-
158
- st.subheader("πŸ“ Detailed Information")
159
- st.write(disease_info["description"])
160
-
161
- st.subheader("🌍 Real-world Summary")
162
- if disease_name.lower() == "healthy":
163
- st.write("The analysis indicates a healthy plant leaf with no signs of disease.")
164
- else:
165
- st.write(f"The analysis detected {disease_name}. Early detection and proper treatment are crucial.")
166
  except Exception as e:
167
  st.error(f"Error processing image: {str(e)}")
168
 
169
  if __name__ == "__main__":
170
- main()
 
4
  import torch.nn as nn
5
  from torchvision import transforms
6
  import os
7
+ from dotenv import load_dotenv
8
  from groq import Groq
9
 
10
+ # Load environment variables
11
+ load_dotenv()
12
+
13
+ # Streamlit UI setup
14
  st.set_page_config(page_title="Leaves Disease Detection", layout="wide")
15
  st.title("🌿 Leaves Disease Detection")
16
  st.write("Upload an image of a plant leaf to check for diseases and get treatment recommendations.")
17
 
18
  # Initialize Groq client
19
  try:
20
+ api_key = os.getenv("GROQ_API_KEY")
21
+ client = Groq(api_key=api_key)
22
  except Exception as e:
23
  st.error(f"Failed to initialize Groq client: {str(e)}")
24
  client = None
25
 
26
+ # Simple CNN model (dummy architecture)
27
  class PlantDiseaseModel(nn.Module):
28
+ def __init__(self, num_classes=28):
29
  super(PlantDiseaseModel, self).__init__()
30
  self.features = nn.Sequential(
31
+ nn.Conv2d(3, 32, 3, padding=1), nn.ReLU(), nn.MaxPool2d(2, 2),
32
+ nn.Conv2d(32, 64, 3, padding=1), nn.ReLU(), nn.MaxPool2d(2, 2),
33
+ nn.Conv2d(64, 128, 3, padding=1), nn.ReLU(), nn.MaxPool2d(2, 2),
 
 
 
 
 
 
34
  )
35
  self.classifier = nn.Sequential(
36
+ nn.Linear(128 * 32 * 32, 512), nn.ReLU(), nn.Dropout(0.5),
 
 
37
  nn.Linear(512, num_classes)
38
  )
39
 
 
43
  x = self.classifier(x)
44
  return x
45
 
46
+ # Cache the model
47
  @st.cache_resource
48
  def load_model():
49
  model = PlantDiseaseModel()
 
52
 
53
  model = load_model()
54
 
55
+ # Preprocess image
56
  def preprocess_image(image):
57
  transform = transforms.Compose([
58
  transforms.Resize((256, 256)),
59
  transforms.ToTensor(),
60
+ transforms.Normalize([0.485, 0.456, 0.406],
61
+ [0.229, 0.224, 0.225]),
62
  ])
63
  return transform(image).unsqueeze(0)
64
 
65
+ # Dummy disease classes
66
+ disease_classes = [
67
+ "Healthy", "Apple Scab", "Apple Black Rot", "Apple Cedar Rust",
68
+ "Cherry Powdery Mildew", "Corn Gray Leaf Spot", "Corn Common Rust",
69
+ "Grape Black Rot", "Grape Esca", "Grape Leaf Blight",
70
+ "Orange Huanglongbing", "Peach Bacterial Spot", "Pepper Bacterial Spot",
71
+ "Potato Early Blight", "Potato Late Blight", "Raspberry Healthy",
72
+ "Soybean Healthy", "Squash Powdery Mildew", "Strawberry Leaf Scorch",
73
+ "Tomato Bacterial Spot", "Tomato Early Blight", "Tomato Late Blight",
74
+ "Tomato Leaf Mold", "Tomato Septoria Leaf Spot", "Tomato Spider Mites",
75
+ "Tomato Target Spot", "Tomato Yellow Leaf Curl Virus", "Tomato Mosaic Virus"
76
+ ]
77
+
78
+ # Classify the image
79
  def classify_disease(image):
80
  try:
81
  img_tensor = preprocess_image(image)
 
83
  outputs = model(img_tensor)
84
  _, predicted = torch.max(outputs, 1)
85
  class_idx = predicted.item()
86
+ return disease_classes[class_idx % len(disease_classes)]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
87
  except Exception as e:
88
  st.error(f"Error during classification: {str(e)}")
89
  return "Unknown"
90
 
91
+ # Fetch info from Groq API
92
  def get_disease_info(disease_name):
93
  if not client:
94
  return {
95
+ "description": "API connection not available. Please check your GROQ_API_KEY.",
 
 
96
  }
 
97
  try:
98
  if disease_name.lower() == "healthy":
99
  return {
100
+ "description": "The plant appears to be healthy. No treatment needed.",
 
 
101
  }
102
+
103
  response = client.chat.completions.create(
104
  messages=[
105
  {"role": "system", "content": "You are a plant pathologist assistant."},
106
+ {"role": "user", "content": f"Describe {disease_name} in plants including symptoms, treatment, and prevention."}
107
  ],
108
  model="mixtral-8x7b-32768",
109
  temperature=0.3,
110
  max_tokens=1024
111
  )
 
112
  return {"description": response.choices[0].message.content}
113
  except Exception as e:
114
  st.error(f"Error fetching disease information: {str(e)}")
115
  return {
116
+ "description": "Unable to fetch disease info. Please try again later.",
 
 
117
  }
118
 
119
+ # Main app function
120
  def main():
121
  uploaded_file = st.file_uploader("Upload a leaf image", type=["jpg", "jpeg", "png"])
122
 
123
+ if uploaded_file:
124
  try:
125
+ image = Image.open(uploaded_file).convert("RGB")
126
  st.image(image, caption="Uploaded Leaf Image", use_column_width=True)
127
+
128
+ if st.button("πŸ” Predict Disease"):
129
  with st.spinner("Analyzing the leaf..."):
130
  disease_name = classify_disease(image)
131
+ info = get_disease_info(disease_name)
132
+
133
+ st.subheader("πŸ”¬ Prediction Results")
134
  col1, col2 = st.columns(2)
 
135
  with col1:
136
  status = "Healthy" if disease_name.lower() == "healthy" else "Diseased"
137
  st.markdown(f"**Status:** {status}")
138
+ st.markdown(f"**Detected Disease:** {disease_name}")
 
139
  with col2:
140
+ if disease_name.lower() == "healthy":
141
+ st.success("βœ… Plant is Healthy")
142
  else:
143
+ st.warning("⚠️ Disease Detected")
144
+
145
+ st.subheader("πŸ“‹ Detailed Information")
146
+ st.write(info["description"])
 
 
 
 
 
 
147
  except Exception as e:
148
  st.error(f"Error processing image: {str(e)}")
149
 
150
  if __name__ == "__main__":
151
+ main()