MuhammmadRizwanRizwan commited on
Commit
52ed7df
·
verified ·
1 Parent(s): 408a8d8

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +229 -229
app.py CHANGED
@@ -29,275 +29,275 @@ except ImportError:
29
  import detectron2
30
 
31
 
32
- # import streamlit as st
33
- # import numpy as np
34
- # import cv2
35
- # import torch
36
- # import os
37
- # from PIL import Image
38
- # from tensorflow.keras.models import load_model
39
- # from tensorflow.keras.preprocessing import image
40
- # from detectron2.engine import DefaultPredictor
41
- # from detectron2.config import get_cfg
42
- # from detectron2.utils.visualizer import Visualizer
43
- # from detectron2.data import MetadataCatalog
44
 
45
- # # Suppress warnings
46
- # import warnings
47
- # import tensorflow as tf
48
- # warnings.filterwarnings("ignore")
49
- # tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)
50
 
51
- # @st.cache_resource
52
- # def load_models():
53
- # model_name = load_model('name_model_inception.h5')
54
- # model_quality = load_model('type_model_inception.h5')
55
- # return model_name, model_quality
56
 
57
- # model_name, model_quality = load_models()
58
 
59
- # # Detectron2 setup
60
- # @st.cache_resource
61
- # def load_detectron_model(fruit_name):
62
- # cfg = get_cfg()
63
- # config_path = os.path.join(f"{fruit_name.lower()}_config.yaml")
64
- # cfg.merge_from_file(config_path)
65
- # model_path = os.path.join(f"{fruit_name}_model.pth")
66
- # cfg.MODEL.WEIGHTS = model_path
67
- # cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = 0.5
68
- # cfg.MODEL.DEVICE = 'cpu'
69
- # predictor = DefaultPredictor(cfg)
70
- # return predictor, cfg
71
-
72
- # # Labels
73
- # label_map_name = {
74
- # 0: "Banana", 1: "Cucumber", 2: "Grape", 3: "Kaki", 4: "Papaya",
75
- # 5: "Peach", 6: "Pear", 7: "Peeper", 8: "Strawberry", 9: "Watermelon",
76
- # 10: "tomato"
77
- # }
78
- # label_map_quality = {0: "Good", 1: "Mild", 2: "Rotten"}
79
 
80
- # def predict_fruit(img):
81
- # # Preprocess image
82
- # img = Image.fromarray(img.astype('uint8'), 'RGB')
83
- # img = img.resize((224, 224))
84
- # x = image.img_to_array(img)
85
- # x = np.expand_dims(x, axis=0)
86
- # x = x / 255.0
87
 
88
- # # Predict
89
- # pred_name = model_name.predict(x)
90
- # pred_quality = model_quality.predict(x)
91
 
92
- # predicted_name = label_map_name[np.argmax(pred_name, axis=1)[0]]
93
- # predicted_quality = label_map_quality[np.argmax(pred_quality, axis=1)[0]]
94
 
95
- # return predicted_name, predicted_quality, img
96
 
97
- # def main():
98
- # st.title("Automated Fruits Monitoring System")
99
- # st.write("Upload an image of a fruit to detect its type, quality, and potential damage.")
100
 
101
- # uploaded_file = st.file_uploader("Choose a fruit image...", type=["jpg", "jpeg", "png"])
102
 
103
- # if uploaded_file is not None:
104
- # image = Image.open(uploaded_file)
105
- # st.image(image, caption="Uploaded Image", use_column_width=True)
106
-
107
- # if st.button("Analyze"):
108
- # predicted_name, predicted_quality, img = predict_fruit(np.array(image))
109
-
110
- # st.write(f"Fruits Type Detection: {predicted_name}")
111
- # st.write(f"Fruits Quality Classification: {predicted_quality}")
112
-
113
- # if predicted_name.lower() in ["kaki", "tomato", "strawberry", "peeper", "pear", "peach", "papaya", "watermelon", "grape", "banana", "cucumber"] and predicted_quality in ["Mild", "Rotten"]:
114
- # st.write("Segmentation of Defective Region:")
115
- # try:
116
- # predictor, cfg = load_detectron_model(predicted_name)
117
- # outputs = predictor(cv2.cvtColor(np.array(img), cv2.COLOR_RGB2BGR))
118
- # v = Visualizer(np.array(img), MetadataCatalog.get(cfg.DATASETS.TRAIN[0]), scale=0.8)
119
- # out = v.draw_instance_predictions(outputs["instances"].to("cpu"))
120
- # st.image(out.get_image(), caption="Damage Detection Result", use_column_width=True)
121
- # except Exception as e:
122
- # st.error(f"Error in damage detection: {str(e)}")
123
- # else:
124
- # st.write("No damage detection performed for this fruit or quality level.")
125
 
126
- # if __name__ == "__main__":
127
- # main()
128
 
129
 
130
 
131
 
132
- import streamlit as st
133
- import numpy as np
134
- import cv2
135
- import warnings
136
- import os
137
- from pathlib import Path
138
- from PIL import Image
139
- import tensorflow as tf
140
- from tensorflow.keras.models import load_model
141
- from tensorflow.keras.preprocessing import image
142
- from detectron2.engine import DefaultPredictor
143
- from detectron2.config import get_cfg
144
- from detectron2.utils.visualizer import Visualizer
145
- from detectron2.data import MetadataCatalog
146
 
147
- # Suppress warnings
148
- warnings.filterwarnings("ignore")
149
- tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)
150
 
151
- # Configuration
152
- MODEL_CONFIG = {
153
- 'name_model': 'name_model_inception.h5',
154
- 'quality_model': 'type_model_inception.h5',
155
- 'input_size': (224, 224),
156
- 'score_threshold': 0.5
157
- }
158
 
159
- LABEL_MAPS = {
160
- 'name': {
161
- 0: "Banana", 1: "Cucumber", 2: "Grape", 3: "Kaki", 4: "Papaya",
162
- 5: "Peach", 6: "Pear", 7: "Peeper", 8: "Strawberry", 9: "Watermelon",
163
- 10: "tomato"
164
- },
165
- 'quality': {0: "Good", 1: "Mild", 2: "Rotten"}
166
- }
167
 
168
- @st.cache_resource
169
- def load_classification_models():
170
- """Load and cache the classification models."""
171
- try:
172
- model_name = load_model(MODEL_CONFIG['name_model'])
173
- model_quality = load_model(MODEL_CONFIG['quality_model'])
174
- return model_name, model_quality
175
- except Exception as e:
176
- st.error(f"Error loading classification models: {str(e)}")
177
- return None, None
178
 
179
- @st.cache_resource
180
- def load_detectron_model(fruit_name: str):
181
- """Load and cache the Detectron2 model for damage detection."""
182
- try:
183
- cfg = get_cfg()
184
- config_path = Path(f"{fruit_name.lower()}_config.yaml")
185
- model_path = Path(f"{fruit_name}_model.pth")
186
 
187
- if not config_path.exists() or not model_path.exists():
188
- raise FileNotFoundError(f"Model files not found for {fruit_name}")
189
 
190
- cfg.merge_from_file(str(config_path))
191
- cfg.MODEL.WEIGHTS = str(model_path)
192
- cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = MODEL_CONFIG['score_threshold']
193
- cfg.MODEL.DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
194
 
195
- return DefaultPredictor(cfg), cfg
196
- except Exception as e:
197
- st.error(f"Error loading Detectron2 model: {str(e)}")
198
- return None, None
199
-
200
- def preprocess_image(img: np.ndarray) -> tuple:
201
- """Preprocess the input image for model prediction."""
202
- try:
203
- # Convert to PIL Image if necessary
204
- if isinstance(img, np.ndarray):
205
- img = Image.fromarray(img.astype('uint8'), 'RGB')
206
 
207
- # Resize and prepare for model input
208
- img_resized = img.resize(MODEL_CONFIG['input_size'])
209
- img_array = image.img_to_array(img_resized)
210
- img_expanded = np.expand_dims(img_array, axis=0)
211
- img_normalized = img_expanded / 255.0
212
 
213
- return img_normalized, img_resized
214
- except Exception as e:
215
- st.error(f"Error preprocessing image: {str(e)}")
216
- return None, None
217
-
218
- def predict_fruit(img: np.ndarray) -> tuple:
219
- """Predict fruit type and quality."""
220
- model_name, model_quality = load_classification_models()
221
- if model_name is None or model_quality is None:
222
- return None, None, None
223
 
224
- img_normalized, img_resized = preprocess_image(img)
225
- if img_normalized is None:
226
- return None, None, None
227
 
228
- try:
229
- # Make predictions
230
- pred_name = model_name.predict(img_normalized)
231
- pred_quality = model_quality.predict(img_normalized)
232
 
233
- # Get predicted labels
234
- predicted_name = LABEL_MAPS['name'][np.argmax(pred_name, axis=1)[0]]
235
- predicted_quality = LABEL_MAPS['quality'][np.argmax(pred_quality, axis=1)[0]]
236
 
237
- return predicted_name, predicted_quality, img_resized
238
- except Exception as e:
239
- st.error(f"Error making predictions: {str(e)}")
240
- return None, None, None
241
-
242
- def detect_damage(img: Image, fruit_name: str) -> np.ndarray:
243
- """Detect and visualize damage in the fruit image."""
244
- predictor, cfg = load_detectron_model(fruit_name)
245
- if predictor is None or cfg is None:
246
- return None
247
 
248
- try:
249
- outputs = predictor(cv2.cvtColor(np.array(img), cv2.COLOR_RGB2BGR))
250
- v = Visualizer(np.array(img), MetadataCatalog.get(cfg.DATASETS.TRAIN[0]), scale=0.8)
251
- out = v.draw_instance_predictions(outputs["instances"].to("cpu"))
252
- return out.get_image()
253
- except Exception as e:
254
- st.error(f"Error in damage detection: {str(e)}")
255
- return None
256
 
257
- def main():
258
- st.set_page_config(page_title="Fruit Quality Analysis", layout="wide")
259
 
260
- st.title("Automated Fruits Monitoring System")
261
- st.write("Upload an image of a fruit to detect its type, quality, and potential damage.")
262
 
263
- uploaded_file = st.file_uploader("Choose a fruit image...", type=["jpg", "jpeg", "png"])
264
 
265
- if uploaded_file is not None:
266
- # Create two columns for layout
267
- col1, col2 = st.columns(2)
268
 
269
- # Display uploaded image
270
- image = Image.open(uploaded_file)
271
- col1.image(image, caption="Uploaded Image", use_column_width=True)
272
 
273
- if col1.button("Analyze"):
274
- with st.spinner("Analyzing image..."):
275
- predicted_name, predicted_quality, img_resized = predict_fruit(np.array(image))
276
 
277
- if predicted_name and predicted_quality:
278
- # Display results
279
- col2.markdown("### Analysis Results")
280
- col2.markdown(f"**Fruit Type:** {predicted_name}")
281
- col2.markdown(f"**Quality:** {predicted_quality}")
282
 
283
- # Check if damage detection is needed
284
- if (predicted_name.lower() in LABEL_MAPS['name'].values() and
285
- predicted_quality in ["Mild", "Rotten"]):
286
 
287
- col2.markdown("### Damage Detection")
288
- damage_image = detect_damage(img_resized, predicted_name)
289
 
290
- if damage_image is not None:
291
- col2.image(damage_image, caption="Detected Damage Regions",
292
- use_column_width=True)
293
 
294
- # Add download button for the damage detection result
295
- col2.download_button(
296
- label="Download Analysis Result",
297
- data=cv2.imencode('.png', damage_image)[1].tobytes(),
298
- file_name=f"{predicted_name}_damage_analysis.png",
299
- mime="image/png"
300
- )
301
 
302
- if __name__ == "__main__":
303
- main()
 
29
  import detectron2
30
 
31
 
32
+ import streamlit as st
33
+ import numpy as np
34
+ import cv2
35
+ import torch
36
+ import os
37
+ from PIL import Image
38
+ from tensorflow.keras.models import load_model
39
+ from tensorflow.keras.preprocessing import image
40
+ from detectron2.engine import DefaultPredictor
41
+ from detectron2.config import get_cfg
42
+ from detectron2.utils.visualizer import Visualizer
43
+ from detectron2.data import MetadataCatalog
44
 
45
+ # Suppress warnings
46
+ import warnings
47
+ import tensorflow as tf
48
+ warnings.filterwarnings("ignore")
49
+ tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)
50
 
51
+ @st.cache_resource
52
+ def load_models():
53
+ model_name = load_model('name_model_inception.h5')
54
+ model_quality = load_model('type_model_inception.h5')
55
+ return model_name, model_quality
56
 
57
+ model_name, model_quality = load_models()
58
 
59
+ # Detectron2 setup
60
+ @st.cache_resource
61
+ def load_detectron_model(fruit_name):
62
+ cfg = get_cfg()
63
+ config_path = os.path.join(f"{fruit_name.lower()}_config.yaml")
64
+ cfg.merge_from_file(config_path)
65
+ model_path = os.path.join(f"{fruit_name}_model.pth")
66
+ cfg.MODEL.WEIGHTS = model_path
67
+ cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = 0.5
68
+ cfg.MODEL.DEVICE = 'cpu'
69
+ predictor = DefaultPredictor(cfg)
70
+ return predictor, cfg
71
+
72
+ # Labels
73
+ label_map_name = {
74
+ 0: "Banana", 1: "Cucumber", 2: "Grape", 3: "Kaki", 4: "Papaya",
75
+ 5: "Peach", 6: "Pear", 7: "Peeper", 8: "Strawberry", 9: "Watermelon",
76
+ 10: "tomato"
77
+ }
78
+ label_map_quality = {0: "Good", 1: "Mild", 2: "Rotten"}
79
 
80
+ def predict_fruit(img):
81
+ # Preprocess image
82
+ img = Image.fromarray(img.astype('uint8'), 'RGB')
83
+ img = img.resize((224, 224))
84
+ x = image.img_to_array(img)
85
+ x = np.expand_dims(x, axis=0)
86
+ x = x / 255.0
87
 
88
+ # Predict
89
+ pred_name = model_name.predict(x)
90
+ pred_quality = model_quality.predict(x)
91
 
92
+ predicted_name = label_map_name[np.argmax(pred_name, axis=1)[0]]
93
+ predicted_quality = label_map_quality[np.argmax(pred_quality, axis=1)[0]]
94
 
95
+ return predicted_name, predicted_quality, img
96
 
97
+ def main():
98
+ st.title("Automated Fruits Monitoring System")
99
+ st.write("Upload an image of a fruit to detect its type, quality, and potential damage.")
100
 
101
+ uploaded_file = st.file_uploader("Choose a fruit image...", type=["jpg", "jpeg", "png"])
102
 
103
+ if uploaded_file is not None:
104
+ image = Image.open(uploaded_file)
105
+ st.image(image, caption="Uploaded Image", use_column_width=True)
106
+
107
+ if st.button("Analyze"):
108
+ predicted_name, predicted_quality, img = predict_fruit(np.array(image))
109
+
110
+ st.write(f"Fruits Type Detection: {predicted_name}")
111
+ st.write(f"Fruits Quality Classification: {predicted_quality}")
112
+
113
+ if predicted_name.lower() in ["kaki", "tomato", "strawberry", "peeper", "pear", "peach", "papaya", "watermelon", "grape", "banana", "cucumber"] and predicted_quality in ["Mild", "Rotten"]:
114
+ st.write("Segmentation of Defective Region:")
115
+ try:
116
+ predictor, cfg = load_detectron_model(predicted_name)
117
+ outputs = predictor(cv2.cvtColor(np.array(img), cv2.COLOR_RGB2BGR))
118
+ v = Visualizer(np.array(img), MetadataCatalog.get(cfg.DATASETS.TRAIN[0]), scale=0.8)
119
+ out = v.draw_instance_predictions(outputs["instances"].to("cpu"))
120
+ st.image(out.get_image(), caption="Damage Detection Result", use_column_width=True)
121
+ except Exception as e:
122
+ st.error(f"Error in damage detection: {str(e)}")
123
+ else:
124
+ st.write("No damage detection performed for this fruit or quality level.")
125
 
126
+ if __name__ == "__main__":
127
+ main()
128
 
129
 
130
 
131
 
132
+ # import streamlit as st
133
+ # import numpy as np
134
+ # import cv2
135
+ # import warnings
136
+ # import os
137
+ # from pathlib import Path
138
+ # from PIL import Image
139
+ # import tensorflow as tf
140
+ # from tensorflow.keras.models import load_model
141
+ # from tensorflow.keras.preprocessing import image
142
+ # from detectron2.engine import DefaultPredictor
143
+ # from detectron2.config import get_cfg
144
+ # from detectron2.utils.visualizer import Visualizer
145
+ # from detectron2.data import MetadataCatalog
146
 
147
+ # # Suppress warnings
148
+ # warnings.filterwarnings("ignore")
149
+ # tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)
150
 
151
+ # # Configuration
152
+ # MODEL_CONFIG = {
153
+ # 'name_model': 'name_model_inception.h5',
154
+ # 'quality_model': 'type_model_inception.h5',
155
+ # 'input_size': (224, 224),
156
+ # 'score_threshold': 0.5
157
+ # }
158
 
159
+ # LABEL_MAPS = {
160
+ # 'name': {
161
+ # 0: "Banana", 1: "Cucumber", 2: "Grape", 3: "Kaki", 4: "Papaya",
162
+ # 5: "Peach", 6: "Pear", 7: "Peeper", 8: "Strawberry", 9: "Watermelon",
163
+ # 10: "tomato"
164
+ # },
165
+ # 'quality': {0: "Good", 1: "Mild", 2: "Rotten"}
166
+ # }
167
 
168
+ # @st.cache_resource
169
+ # def load_classification_models():
170
+ # """Load and cache the classification models."""
171
+ # try:
172
+ # model_name = load_model(MODEL_CONFIG['name_model'])
173
+ # model_quality = load_model(MODEL_CONFIG['quality_model'])
174
+ # return model_name, model_quality
175
+ # except Exception as e:
176
+ # st.error(f"Error loading classification models: {str(e)}")
177
+ # return None, None
178
 
179
+ # @st.cache_resource
180
+ # def load_detectron_model(fruit_name: str):
181
+ # """Load and cache the Detectron2 model for damage detection."""
182
+ # try:
183
+ # cfg = get_cfg()
184
+ # config_path = Path(f"{fruit_name.lower()}_config.yaml")
185
+ # model_path = Path(f"{fruit_name}_model.pth")
186
 
187
+ # if not config_path.exists() or not model_path.exists():
188
+ # raise FileNotFoundError(f"Model files not found for {fruit_name}")
189
 
190
+ # cfg.merge_from_file(str(config_path))
191
+ # cfg.MODEL.WEIGHTS = str(model_path)
192
+ # cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = MODEL_CONFIG['score_threshold']
193
+ # cfg.MODEL.DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
194
 
195
+ # return DefaultPredictor(cfg), cfg
196
+ # except Exception as e:
197
+ # st.error(f"Error loading Detectron2 model: {str(e)}")
198
+ # return None, None
199
+
200
+ # def preprocess_image(img: np.ndarray) -> tuple:
201
+ # """Preprocess the input image for model prediction."""
202
+ # try:
203
+ # # Convert to PIL Image if necessary
204
+ # if isinstance(img, np.ndarray):
205
+ # img = Image.fromarray(img.astype('uint8'), 'RGB')
206
 
207
+ # # Resize and prepare for model input
208
+ # img_resized = img.resize(MODEL_CONFIG['input_size'])
209
+ # img_array = image.img_to_array(img_resized)
210
+ # img_expanded = np.expand_dims(img_array, axis=0)
211
+ # img_normalized = img_expanded / 255.0
212
 
213
+ # return img_normalized, img_resized
214
+ # except Exception as e:
215
+ # st.error(f"Error preprocessing image: {str(e)}")
216
+ # return None, None
217
+
218
+ # def predict_fruit(img: np.ndarray) -> tuple:
219
+ # """Predict fruit type and quality."""
220
+ # model_name, model_quality = load_classification_models()
221
+ # if model_name is None or model_quality is None:
222
+ # return None, None, None
223
 
224
+ # img_normalized, img_resized = preprocess_image(img)
225
+ # if img_normalized is None:
226
+ # return None, None, None
227
 
228
+ # try:
229
+ # # Make predictions
230
+ # pred_name = model_name.predict(img_normalized)
231
+ # pred_quality = model_quality.predict(img_normalized)
232
 
233
+ # # Get predicted labels
234
+ # predicted_name = LABEL_MAPS['name'][np.argmax(pred_name, axis=1)[0]]
235
+ # predicted_quality = LABEL_MAPS['quality'][np.argmax(pred_quality, axis=1)[0]]
236
 
237
+ # return predicted_name, predicted_quality, img_resized
238
+ # except Exception as e:
239
+ # st.error(f"Error making predictions: {str(e)}")
240
+ # return None, None, None
241
+
242
+ # def detect_damage(img: Image, fruit_name: str) -> np.ndarray:
243
+ # """Detect and visualize damage in the fruit image."""
244
+ # predictor, cfg = load_detectron_model(fruit_name)
245
+ # if predictor is None or cfg is None:
246
+ # return None
247
 
248
+ # try:
249
+ # outputs = predictor(cv2.cvtColor(np.array(img), cv2.COLOR_RGB2BGR))
250
+ # v = Visualizer(np.array(img), MetadataCatalog.get(cfg.DATASETS.TRAIN[0]), scale=0.8)
251
+ # out = v.draw_instance_predictions(outputs["instances"].to("cpu"))
252
+ # return out.get_image()
253
+ # except Exception as e:
254
+ # st.error(f"Error in damage detection: {str(e)}")
255
+ # return None
256
 
257
+ # def main():
258
+ # st.set_page_config(page_title="Fruit Quality Analysis", layout="wide")
259
 
260
+ # st.title("Automated Fruits Monitoring System")
261
+ # st.write("Upload an image of a fruit to detect its type, quality, and potential damage.")
262
 
263
+ # uploaded_file = st.file_uploader("Choose a fruit image...", type=["jpg", "jpeg", "png"])
264
 
265
+ # if uploaded_file is not None:
266
+ # # Create two columns for layout
267
+ # col1, col2 = st.columns(2)
268
 
269
+ # # Display uploaded image
270
+ # image = Image.open(uploaded_file)
271
+ # col1.image(image, caption="Uploaded Image", use_column_width=True)
272
 
273
+ # if col1.button("Analyze"):
274
+ # with st.spinner("Analyzing image..."):
275
+ # predicted_name, predicted_quality, img_resized = predict_fruit(np.array(image))
276
 
277
+ # if predicted_name and predicted_quality:
278
+ # # Display results
279
+ # col2.markdown("### Analysis Results")
280
+ # col2.markdown(f"**Fruit Type:** {predicted_name}")
281
+ # col2.markdown(f"**Quality:** {predicted_quality}")
282
 
283
+ # # Check if damage detection is needed
284
+ # if (predicted_name.lower() in LABEL_MAPS['name'].values() and
285
+ # predicted_quality in ["Mild", "Rotten"]):
286
 
287
+ # col2.markdown("### Damage Detection")
288
+ # damage_image = detect_damage(img_resized, predicted_name)
289
 
290
+ # if damage_image is not None:
291
+ # col2.image(damage_image, caption="Detected Damage Regions",
292
+ # use_column_width=True)
293
 
294
+ # # Add download button for the damage detection result
295
+ # col2.download_button(
296
+ # label="Download Analysis Result",
297
+ # data=cv2.imencode('.png', damage_image)[1].tobytes(),
298
+ # file_name=f"{predicted_name}_damage_analysis.png",
299
+ # mime="image/png"
300
+ # )
301
 
302
+ # if __name__ == "__main__":
303
+ # main()