MuhammmadRizwanRizwan commited on
Commit
1912d57
·
verified ·
1 Parent(s): ea250f3

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +264 -88
app.py CHANGED
@@ -1,40 +1,142 @@
1
 
2
- import streamlit as st
3
- import numpy as np
4
- import cv2
5
- import warnings
6
- import os
7
 
8
- # Suppress warnings
9
- warnings.filterwarnings("ignore", category=FutureWarning)
10
- warnings.filterwarnings("ignore", category=UserWarning)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
11
 
12
- # Try importing TensorFlow
13
- try:
14
- from tensorflow.keras.models import load_model
15
- from tensorflow.keras.preprocessing import image
16
- except ImportError:
17
- st.error("Failed to import TensorFlow. Please make sure it's installed correctly.")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
18
 
19
- # Try importing PyTorch and Detectron2
20
- try:
21
- import torch
22
- import detectron2
23
- except ImportError:
24
- with st.spinner("Installing PyTorch and Detectron2..."):
25
- os.system("pip install torch torchvision")
26
- os.system("pip install 'git+https://github.com/facebookresearch/detectron2.git'")
27
 
28
- import torch
29
- import detectron2
30
 
31
 
32
  import streamlit as st
33
  import numpy as np
34
  import cv2
35
- import torch
36
  import os
 
37
  from PIL import Image
 
38
  from tensorflow.keras.models import load_model
39
  from tensorflow.keras.preprocessing import image
40
  from detectron2.engine import DefaultPredictor
@@ -43,85 +145,159 @@ from detectron2.utils.visualizer import Visualizer
43
  from detectron2.data import MetadataCatalog
44
 
45
  # Suppress warnings
46
- import warnings
47
- import tensorflow as tf
48
  warnings.filterwarnings("ignore")
49
  tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)
50
 
51
- @st.cache_resource
52
- def load_models():
53
- model_name = load_model('name_model_inception.h5')
54
- model_quality = load_model('type_model_inception.h5')
55
- return model_name, model_quality
 
 
56
 
57
- model_name, model_quality = load_models()
 
 
 
 
 
 
 
58
 
59
- # Detectron2 setup
60
  @st.cache_resource
61
- def load_detectron_model(fruit_name):
62
- cfg = get_cfg()
63
- config_path = os.path.join(f"{fruit_name.lower()}_config.yaml")
64
- cfg.merge_from_file(config_path)
65
- model_path = os.path.join(f"{fruit_name}_model.pth")
66
- cfg.MODEL.WEIGHTS = model_path
67
- cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = 0.5
68
- cfg.MODEL.DEVICE = 'cpu'
69
- predictor = DefaultPredictor(cfg)
70
- return predictor, cfg
71
-
72
- # Labels
73
- label_map_name = {
74
- 0: "Banana", 1: "Cucumber", 2: "Grape", 3: "Kaki", 4: "Papaya",
75
- 5: "Peach", 6: "Pear", 7: "Peeper", 8: "Strawberry", 9: "Watermelon",
76
- 10: "tomato"
77
- }
78
- label_map_quality = {0: "Good", 1: "Mild", 2: "Rotten"}
79
 
80
- def predict_fruit(img):
81
- # Preprocess image
82
- img = Image.fromarray(img.astype('uint8'), 'RGB')
83
- img = img.resize((224, 224))
84
- x = image.img_to_array(img)
85
- x = np.expand_dims(x, axis=0)
86
- x = x / 255.0
 
 
 
 
 
 
 
 
 
 
 
 
 
87
 
88
- # Predict
89
- pred_name = model_name.predict(x)
90
- pred_quality = model_quality.predict(x)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
91
 
92
- predicted_name = label_map_name[np.argmax(pred_name, axis=1)[0]]
93
- predicted_quality = label_map_quality[np.argmax(pred_quality, axis=1)[0]]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
94
 
95
- return predicted_name, predicted_quality, img
 
 
 
 
 
 
 
 
 
 
 
 
 
96
 
97
  def main():
98
- st.title("Automated Fruits Monitoring System")
 
 
99
  st.write("Upload an image of a fruit to detect its type, quality, and potential damage.")
100
-
101
  uploaded_file = st.file_uploader("Choose a fruit image...", type=["jpg", "jpeg", "png"])
102
-
103
  if uploaded_file is not None:
 
 
 
 
104
  image = Image.open(uploaded_file)
105
- st.image(image, caption="Uploaded Image", use_column_width=True)
106
-
107
- if st.button("Analyze"):
108
- predicted_name, predicted_quality, img = predict_fruit(np.array(image))
109
-
110
- st.write(f"Fruits Type Detection: {predicted_name}")
111
- st.write(f"Fruits Quality Classification: {predicted_quality}")
112
-
113
- if predicted_name.lower() in ["kaki", "tomato", "strawberry", "peeper", "pear", "peach", "papaya", "watermelon", "grape", "banana", "cucumber"] and predicted_quality in ["Mild", "Rotten"]:
114
- st.write("Segmentation of Defective Region:")
115
- try:
116
- predictor, cfg = load_detectron_model(predicted_name)
117
- outputs = predictor(cv2.cvtColor(np.array(img), cv2.COLOR_RGB2BGR))
118
- v = Visualizer(np.array(img), MetadataCatalog.get(cfg.DATASETS.TRAIN[0]), scale=0.8)
119
- out = v.draw_instance_predictions(outputs["instances"].to("cpu"))
120
- st.image(out.get_image(), caption="Damage Detection Result", use_column_width=True)
121
- except Exception as e:
122
- st.error(f"Error in damage detection: {str(e)}")
123
- else:
124
- st.write("No damage detection performed for this fruit or quality level.")
 
 
 
 
 
 
 
 
 
 
125
 
126
  if __name__ == "__main__":
127
  main()
 
1
 
2
+ # import streamlit as st
3
+ # import numpy as np
4
+ # import cv2
5
+ # import warnings
6
+ # import os
7
 
8
+ # # Suppress warnings
9
+ # warnings.filterwarnings("ignore", category=FutureWarning)
10
+ # warnings.filterwarnings("ignore", category=UserWarning)
11
+
12
+ # # Try importing TensorFlow
13
+ # try:
14
+ # from tensorflow.keras.models import load_model
15
+ # from tensorflow.keras.preprocessing import image
16
+ # except ImportError:
17
+ # st.error("Failed to import TensorFlow. Please make sure it's installed correctly.")
18
+
19
+ # # Try importing PyTorch and Detectron2
20
+ # try:
21
+ # import torch
22
+ # import detectron2
23
+ # except ImportError:
24
+ # with st.spinner("Installing PyTorch and Detectron2..."):
25
+ # os.system("pip install torch torchvision")
26
+ # os.system("pip install 'git+https://github.com/facebookresearch/detectron2.git'")
27
+
28
+ # import torch
29
+ # import detectron2
30
+
31
+
32
+ # import streamlit as st
33
+ # import numpy as np
34
+ # import cv2
35
+ # import torch
36
+ # import os
37
+ # from PIL import Image
38
+ # from tensorflow.keras.models import load_model
39
+ # from tensorflow.keras.preprocessing import image
40
+ # from detectron2.engine import DefaultPredictor
41
+ # from detectron2.config import get_cfg
42
+ # from detectron2.utils.visualizer import Visualizer
43
+ # from detectron2.data import MetadataCatalog
44
+
45
+ # # Suppress warnings
46
+ # import warnings
47
+ # import tensorflow as tf
48
+ # warnings.filterwarnings("ignore")
49
+ # tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)
50
+
51
+ # @st.cache_resource
52
+ # def load_models():
53
+ # model_name = load_model('name_model_inception.h5')
54
+ # model_quality = load_model('type_model_inception.h5')
55
+ # return model_name, model_quality
56
+
57
+ # model_name, model_quality = load_models()
58
+
59
+ # # Detectron2 setup
60
+ # @st.cache_resource
61
+ # def load_detectron_model(fruit_name):
62
+ # cfg = get_cfg()
63
+ # config_path = os.path.join(f"{fruit_name.lower()}_config.yaml")
64
+ # cfg.merge_from_file(config_path)
65
+ # model_path = os.path.join(f"{fruit_name}_model.pth")
66
+ # cfg.MODEL.WEIGHTS = model_path
67
+ # cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = 0.5
68
+ # cfg.MODEL.DEVICE = 'cpu'
69
+ # predictor = DefaultPredictor(cfg)
70
+ # return predictor, cfg
71
+
72
+ # # Labels
73
+ # label_map_name = {
74
+ # 0: "Banana", 1: "Cucumber", 2: "Grape", 3: "Kaki", 4: "Papaya",
75
+ # 5: "Peach", 6: "Pear", 7: "Peeper", 8: "Strawberry", 9: "Watermelon",
76
+ # 10: "tomato"
77
+ # }
78
+ # label_map_quality = {0: "Good", 1: "Mild", 2: "Rotten"}
79
+
80
+ # def predict_fruit(img):
81
+ # # Preprocess image
82
+ # img = Image.fromarray(img.astype('uint8'), 'RGB')
83
+ # img = img.resize((224, 224))
84
+ # x = image.img_to_array(img)
85
+ # x = np.expand_dims(x, axis=0)
86
+ # x = x / 255.0
87
+
88
+ # # Predict
89
+ # pred_name = model_name.predict(x)
90
+ # pred_quality = model_quality.predict(x)
91
+
92
+ # predicted_name = label_map_name[np.argmax(pred_name, axis=1)[0]]
93
+ # predicted_quality = label_map_quality[np.argmax(pred_quality, axis=1)[0]]
94
 
95
+ # return predicted_name, predicted_quality, img
96
+
97
+ # def main():
98
+ # st.title("Automated Fruits Monitoring System")
99
+ # st.write("Upload an image of a fruit to detect its type, quality, and potential damage.")
100
+
101
+ # uploaded_file = st.file_uploader("Choose a fruit image...", type=["jpg", "jpeg", "png"])
102
+
103
+ # if uploaded_file is not None:
104
+ # image = Image.open(uploaded_file)
105
+ # st.image(image, caption="Uploaded Image", use_column_width=True)
106
+
107
+ # if st.button("Analyze"):
108
+ # predicted_name, predicted_quality, img = predict_fruit(np.array(image))
109
+
110
+ # st.write(f"Fruits Type Detection: {predicted_name}")
111
+ # st.write(f"Fruits Quality Classification: {predicted_quality}")
112
+
113
+ # if predicted_name.lower() in ["kaki", "tomato", "strawberry", "peeper", "pear", "peach", "papaya", "watermelon", "grape", "banana", "cucumber"] and predicted_quality in ["Mild", "Rotten"]:
114
+ # st.write("Segmentation of Defective Region:")
115
+ # try:
116
+ # predictor, cfg = load_detectron_model(predicted_name)
117
+ # outputs = predictor(cv2.cvtColor(np.array(img), cv2.COLOR_RGB2BGR))
118
+ # v = Visualizer(np.array(img), MetadataCatalog.get(cfg.DATASETS.TRAIN[0]), scale=0.8)
119
+ # out = v.draw_instance_predictions(outputs["instances"].to("cpu"))
120
+ # st.image(out.get_image(), caption="Damage Detection Result", use_column_width=True)
121
+ # except Exception as e:
122
+ # st.error(f"Error in damage detection: {str(e)}")
123
+ # else:
124
+ # st.write("No damage detection performed for this fruit or quality level.")
125
+
126
+ # if __name__ == "__main__":
127
+ # main()
128
 
 
 
 
 
 
 
 
 
129
 
 
 
130
 
131
 
132
  import streamlit as st
133
  import numpy as np
134
  import cv2
135
+ import warnings
136
  import os
137
+ from pathlib import Path
138
  from PIL import Image
139
+ import tensorflow as tf
140
  from tensorflow.keras.models import load_model
141
  from tensorflow.keras.preprocessing import image
142
  from detectron2.engine import DefaultPredictor
 
145
  from detectron2.data import MetadataCatalog
146
 
147
  # Suppress warnings
 
 
148
  warnings.filterwarnings("ignore")
149
  tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)
150
 
151
+ # Configuration
152
+ MODEL_CONFIG = {
153
+ 'name_model': 'name_model_inception.h5',
154
+ 'quality_model': 'type_model_inception.h5',
155
+ 'input_size': (224, 224),
156
+ 'score_threshold': 0.5
157
+ }
158
 
159
+ LABEL_MAPS = {
160
+ 'name': {
161
+ 0: "Banana", 1: "Cucumber", 2: "Grape", 3: "Kaki", 4: "Papaya",
162
+ 5: "Peach", 6: "Pear", 7: "Peeper", 8: "Strawberry", 9: "Watermelon",
163
+ 10: "tomato"
164
+ },
165
+ 'quality': {0: "Good", 1: "Mild", 2: "Rotten"}
166
+ }
167
 
 
168
  @st.cache_resource
169
+ def load_classification_models():
170
+ """Load and cache the classification models."""
171
+ try:
172
+ model_name = load_model(MODEL_CONFIG['name_model'])
173
+ model_quality = load_model(MODEL_CONFIG['quality_model'])
174
+ return model_name, model_quality
175
+ except Exception as e:
176
+ st.error(f"Error loading classification models: {str(e)}")
177
+ return None, None
 
 
 
 
 
 
 
 
 
178
 
179
+ @st.cache_resource
180
+ def load_detectron_model(fruit_name: str):
181
+ """Load and cache the Detectron2 model for damage detection."""
182
+ try:
183
+ cfg = get_cfg()
184
+ config_path = Path(f"{fruit_name.lower()}_config.yaml")
185
+ model_path = Path(f"{fruit_name}_model.pth")
186
+
187
+ if not config_path.exists() or not model_path.exists():
188
+ raise FileNotFoundError(f"Model files not found for {fruit_name}")
189
+
190
+ cfg.merge_from_file(str(config_path))
191
+ cfg.MODEL.WEIGHTS = str(model_path)
192
+ cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = MODEL_CONFIG['score_threshold']
193
+ cfg.MODEL.DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
194
+
195
+ return DefaultPredictor(cfg), cfg
196
+ except Exception as e:
197
+ st.error(f"Error loading Detectron2 model: {str(e)}")
198
+ return None, None
199
 
200
+ def preprocess_image(img: np.ndarray) -> tuple:
201
+ """Preprocess the input image for model prediction."""
202
+ try:
203
+ # Convert to PIL Image if necessary
204
+ if isinstance(img, np.ndarray):
205
+ img = Image.fromarray(img.astype('uint8'), 'RGB')
206
+
207
+ # Resize and prepare for model input
208
+ img_resized = img.resize(MODEL_CONFIG['input_size'])
209
+ img_array = image.img_to_array(img_resized)
210
+ img_expanded = np.expand_dims(img_array, axis=0)
211
+ img_normalized = img_expanded / 255.0
212
+
213
+ return img_normalized, img_resized
214
+ except Exception as e:
215
+ st.error(f"Error preprocessing image: {str(e)}")
216
+ return None, None
217
 
218
+ def predict_fruit(img: np.ndarray) -> tuple:
219
+ """Predict fruit type and quality."""
220
+ model_name, model_quality = load_classification_models()
221
+ if model_name is None or model_quality is None:
222
+ return None, None, None
223
+
224
+ img_normalized, img_resized = preprocess_image(img)
225
+ if img_normalized is None:
226
+ return None, None, None
227
+
228
+ try:
229
+ # Make predictions
230
+ pred_name = model_name.predict(img_normalized)
231
+ pred_quality = model_quality.predict(img_normalized)
232
+
233
+ # Get predicted labels
234
+ predicted_name = LABEL_MAPS['name'][np.argmax(pred_name, axis=1)[0]]
235
+ predicted_quality = LABEL_MAPS['quality'][np.argmax(pred_quality, axis=1)[0]]
236
+
237
+ return predicted_name, predicted_quality, img_resized
238
+ except Exception as e:
239
+ st.error(f"Error making predictions: {str(e)}")
240
+ return None, None, None
241
 
242
+ def detect_damage(img: Image, fruit_name: str) -> np.ndarray:
243
+ """Detect and visualize damage in the fruit image."""
244
+ predictor, cfg = load_detectron_model(fruit_name)
245
+ if predictor is None or cfg is None:
246
+ return None
247
+
248
+ try:
249
+ outputs = predictor(cv2.cvtColor(np.array(img), cv2.COLOR_RGB2BGR))
250
+ v = Visualizer(np.array(img), MetadataCatalog.get(cfg.DATASETS.TRAIN[0]), scale=0.8)
251
+ out = v.draw_instance_predictions(outputs["instances"].to("cpu"))
252
+ return out.get_image()
253
+ except Exception as e:
254
+ st.error(f"Error in damage detection: {str(e)}")
255
+ return None
256
 
257
  def main():
258
+ st.set_page_config(page_title="Fruit Quality Analysis", layout="wide")
259
+
260
+ st.title("Automated Fruits Monitoring System")
261
  st.write("Upload an image of a fruit to detect its type, quality, and potential damage.")
262
+
263
  uploaded_file = st.file_uploader("Choose a fruit image...", type=["jpg", "jpeg", "png"])
264
+
265
  if uploaded_file is not None:
266
+ # Create two columns for layout
267
+ col1, col2 = st.columns(2)
268
+
269
+ # Display uploaded image
270
  image = Image.open(uploaded_file)
271
+ col1.image(image, caption="Uploaded Image", use_column_width=True)
272
+
273
+ if col1.button("Analyze"):
274
+ with st.spinner("Analyzing image..."):
275
+ predicted_name, predicted_quality, img_resized = predict_fruit(np.array(image))
276
+
277
+ if predicted_name and predicted_quality:
278
+ # Display results
279
+ col2.markdown("### Analysis Results")
280
+ col2.markdown(f"**Fruit Type:** {predicted_name}")
281
+ col2.markdown(f"**Quality:** {predicted_quality}")
282
+
283
+ # Check if damage detection is needed
284
+ if (predicted_name.lower() in LABEL_MAPS['name'].values() and
285
+ predicted_quality in ["Mild", "Rotten"]):
286
+
287
+ col2.markdown("### Damage Detection")
288
+ damage_image = detect_damage(img_resized, predicted_name)
289
+
290
+ if damage_image is not None:
291
+ col2.image(damage_image, caption="Detected Damage Regions",
292
+ use_column_width=True)
293
+
294
+ # Add download button for the damage detection result
295
+ col2.download_button(
296
+ label="Download Analysis Result",
297
+ data=cv2.imencode('.png', damage_image)[1].tobytes(),
298
+ file_name=f"{predicted_name}_damage_analysis.png",
299
+ mime="image/png"
300
+ )
301
 
302
  if __name__ == "__main__":
303
  main()