Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
@@ -29,275 +29,275 @@ except ImportError:
|
|
29 |
import detectron2
|
30 |
|
31 |
|
32 |
-
|
33 |
-
|
34 |
-
|
35 |
-
|
36 |
-
|
37 |
-
|
38 |
-
|
39 |
-
|
40 |
-
|
41 |
-
|
42 |
-
|
43 |
-
|
44 |
|
45 |
-
#
|
46 |
-
|
47 |
-
|
48 |
-
|
49 |
-
|
50 |
|
51 |
-
|
52 |
-
|
53 |
-
|
54 |
-
|
55 |
-
|
56 |
|
57 |
-
|
58 |
|
59 |
-
#
|
60 |
-
|
61 |
-
|
62 |
-
|
63 |
-
|
64 |
-
|
65 |
-
|
66 |
-
|
67 |
-
|
68 |
-
|
69 |
-
|
70 |
-
|
71 |
-
|
72 |
-
#
|
73 |
-
|
74 |
-
|
75 |
-
|
76 |
-
|
77 |
-
|
78 |
-
|
79 |
|
80 |
-
|
81 |
-
#
|
82 |
-
|
83 |
-
|
84 |
-
|
85 |
-
|
86 |
-
|
87 |
|
88 |
-
#
|
89 |
-
|
90 |
-
|
91 |
|
92 |
-
|
93 |
-
|
94 |
|
95 |
-
|
96 |
|
97 |
-
|
98 |
-
|
99 |
-
|
100 |
|
101 |
-
|
102 |
|
103 |
-
|
104 |
-
|
105 |
-
|
106 |
-
|
107 |
-
|
108 |
-
|
109 |
-
|
110 |
-
|
111 |
-
|
112 |
-
|
113 |
-
|
114 |
-
|
115 |
-
|
116 |
-
|
117 |
-
|
118 |
-
|
119 |
-
|
120 |
-
|
121 |
-
|
122 |
-
|
123 |
-
|
124 |
-
|
125 |
|
126 |
-
|
127 |
-
|
128 |
|
129 |
|
130 |
|
131 |
|
132 |
-
import streamlit as st
|
133 |
-
import numpy as np
|
134 |
-
import cv2
|
135 |
-
import warnings
|
136 |
-
import os
|
137 |
-
from pathlib import Path
|
138 |
-
from PIL import Image
|
139 |
-
import tensorflow as tf
|
140 |
-
from tensorflow.keras.models import load_model
|
141 |
-
from tensorflow.keras.preprocessing import image
|
142 |
-
from detectron2.engine import DefaultPredictor
|
143 |
-
from detectron2.config import get_cfg
|
144 |
-
from detectron2.utils.visualizer import Visualizer
|
145 |
-
from detectron2.data import MetadataCatalog
|
146 |
|
147 |
-
# Suppress warnings
|
148 |
-
warnings.filterwarnings("ignore")
|
149 |
-
tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)
|
150 |
|
151 |
-
# Configuration
|
152 |
-
MODEL_CONFIG = {
|
153 |
-
|
154 |
-
|
155 |
-
|
156 |
-
|
157 |
-
}
|
158 |
|
159 |
-
LABEL_MAPS = {
|
160 |
-
|
161 |
-
|
162 |
-
|
163 |
-
|
164 |
-
|
165 |
-
|
166 |
-
}
|
167 |
|
168 |
-
@st.cache_resource
|
169 |
-
def load_classification_models():
|
170 |
-
|
171 |
-
|
172 |
-
|
173 |
-
|
174 |
-
|
175 |
-
|
176 |
-
|
177 |
-
|
178 |
|
179 |
-
@st.cache_resource
|
180 |
-
def load_detectron_model(fruit_name: str):
|
181 |
-
|
182 |
-
|
183 |
-
|
184 |
-
|
185 |
-
|
186 |
|
187 |
-
|
188 |
-
|
189 |
|
190 |
-
|
191 |
-
|
192 |
-
|
193 |
-
|
194 |
|
195 |
-
|
196 |
-
|
197 |
-
|
198 |
-
|
199 |
-
|
200 |
-
def preprocess_image(img: np.ndarray) -> tuple:
|
201 |
-
|
202 |
-
|
203 |
-
|
204 |
-
|
205 |
-
|
206 |
|
207 |
-
|
208 |
-
|
209 |
-
|
210 |
-
|
211 |
-
|
212 |
|
213 |
-
|
214 |
-
|
215 |
-
|
216 |
-
|
217 |
-
|
218 |
-
def predict_fruit(img: np.ndarray) -> tuple:
|
219 |
-
|
220 |
-
|
221 |
-
|
222 |
-
|
223 |
|
224 |
-
|
225 |
-
|
226 |
-
|
227 |
|
228 |
-
|
229 |
-
|
230 |
-
|
231 |
-
|
232 |
|
233 |
-
|
234 |
-
|
235 |
-
|
236 |
|
237 |
-
|
238 |
-
|
239 |
-
|
240 |
-
|
241 |
-
|
242 |
-
def detect_damage(img: Image, fruit_name: str) -> np.ndarray:
|
243 |
-
|
244 |
-
|
245 |
-
|
246 |
-
|
247 |
|
248 |
-
|
249 |
-
|
250 |
-
|
251 |
-
|
252 |
-
|
253 |
-
|
254 |
-
|
255 |
-
|
256 |
|
257 |
-
def main():
|
258 |
-
|
259 |
|
260 |
-
|
261 |
-
|
262 |
|
263 |
-
|
264 |
|
265 |
-
|
266 |
-
|
267 |
-
|
268 |
|
269 |
-
|
270 |
-
|
271 |
-
|
272 |
|
273 |
-
|
274 |
-
|
275 |
-
|
276 |
|
277 |
-
|
278 |
-
|
279 |
-
|
280 |
-
|
281 |
-
|
282 |
|
283 |
-
|
284 |
-
|
285 |
-
|
286 |
|
287 |
-
|
288 |
-
|
289 |
|
290 |
-
|
291 |
-
|
292 |
-
|
293 |
|
294 |
-
|
295 |
-
|
296 |
-
|
297 |
-
|
298 |
-
|
299 |
-
|
300 |
-
|
301 |
|
302 |
-
if __name__ == "__main__":
|
303 |
-
|
|
|
29 |
import detectron2
|
30 |
|
31 |
|
32 |
+
import streamlit as st
|
33 |
+
import numpy as np
|
34 |
+
import cv2
|
35 |
+
import torch
|
36 |
+
import os
|
37 |
+
from PIL import Image
|
38 |
+
from tensorflow.keras.models import load_model
|
39 |
+
from tensorflow.keras.preprocessing import image
|
40 |
+
from detectron2.engine import DefaultPredictor
|
41 |
+
from detectron2.config import get_cfg
|
42 |
+
from detectron2.utils.visualizer import Visualizer
|
43 |
+
from detectron2.data import MetadataCatalog
|
44 |
|
45 |
+
# Suppress warnings
|
46 |
+
import warnings
|
47 |
+
import tensorflow as tf
|
48 |
+
warnings.filterwarnings("ignore")
|
49 |
+
tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)
|
50 |
|
51 |
+
@st.cache_resource
|
52 |
+
def load_models():
|
53 |
+
model_name = load_model('name_model_inception.h5')
|
54 |
+
model_quality = load_model('type_model_inception.h5')
|
55 |
+
return model_name, model_quality
|
56 |
|
57 |
+
model_name, model_quality = load_models()
|
58 |
|
59 |
+
# Detectron2 setup
|
60 |
+
@st.cache_resource
|
61 |
+
def load_detectron_model(fruit_name):
|
62 |
+
cfg = get_cfg()
|
63 |
+
config_path = os.path.join(f"{fruit_name.lower()}_config.yaml")
|
64 |
+
cfg.merge_from_file(config_path)
|
65 |
+
model_path = os.path.join(f"{fruit_name}_model.pth")
|
66 |
+
cfg.MODEL.WEIGHTS = model_path
|
67 |
+
cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = 0.5
|
68 |
+
cfg.MODEL.DEVICE = 'cpu'
|
69 |
+
predictor = DefaultPredictor(cfg)
|
70 |
+
return predictor, cfg
|
71 |
+
|
72 |
+
# Labels
|
73 |
+
label_map_name = {
|
74 |
+
0: "Banana", 1: "Cucumber", 2: "Grape", 3: "Kaki", 4: "Papaya",
|
75 |
+
5: "Peach", 6: "Pear", 7: "Peeper", 8: "Strawberry", 9: "Watermelon",
|
76 |
+
10: "tomato"
|
77 |
+
}
|
78 |
+
label_map_quality = {0: "Good", 1: "Mild", 2: "Rotten"}
|
79 |
|
80 |
+
def predict_fruit(img):
|
81 |
+
# Preprocess image
|
82 |
+
img = Image.fromarray(img.astype('uint8'), 'RGB')
|
83 |
+
img = img.resize((224, 224))
|
84 |
+
x = image.img_to_array(img)
|
85 |
+
x = np.expand_dims(x, axis=0)
|
86 |
+
x = x / 255.0
|
87 |
|
88 |
+
# Predict
|
89 |
+
pred_name = model_name.predict(x)
|
90 |
+
pred_quality = model_quality.predict(x)
|
91 |
|
92 |
+
predicted_name = label_map_name[np.argmax(pred_name, axis=1)[0]]
|
93 |
+
predicted_quality = label_map_quality[np.argmax(pred_quality, axis=1)[0]]
|
94 |
|
95 |
+
return predicted_name, predicted_quality, img
|
96 |
|
97 |
+
def main():
|
98 |
+
st.title("Automated Fruits Monitoring System")
|
99 |
+
st.write("Upload an image of a fruit to detect its type, quality, and potential damage.")
|
100 |
|
101 |
+
uploaded_file = st.file_uploader("Choose a fruit image...", type=["jpg", "jpeg", "png"])
|
102 |
|
103 |
+
if uploaded_file is not None:
|
104 |
+
image = Image.open(uploaded_file)
|
105 |
+
st.image(image, caption="Uploaded Image", use_column_width=True)
|
106 |
+
|
107 |
+
if st.button("Analyze"):
|
108 |
+
predicted_name, predicted_quality, img = predict_fruit(np.array(image))
|
109 |
+
|
110 |
+
st.write(f"Fruits Type Detection: {predicted_name}")
|
111 |
+
st.write(f"Fruits Quality Classification: {predicted_quality}")
|
112 |
+
|
113 |
+
if predicted_name.lower() in ["kaki", "tomato", "strawberry", "peeper", "pear", "peach", "papaya", "watermelon", "grape", "banana", "cucumber"] and predicted_quality in ["Mild", "Rotten"]:
|
114 |
+
st.write("Segmentation of Defective Region:")
|
115 |
+
try:
|
116 |
+
predictor, cfg = load_detectron_model(predicted_name)
|
117 |
+
outputs = predictor(cv2.cvtColor(np.array(img), cv2.COLOR_RGB2BGR))
|
118 |
+
v = Visualizer(np.array(img), MetadataCatalog.get(cfg.DATASETS.TRAIN[0]), scale=0.8)
|
119 |
+
out = v.draw_instance_predictions(outputs["instances"].to("cpu"))
|
120 |
+
st.image(out.get_image(), caption="Damage Detection Result", use_column_width=True)
|
121 |
+
except Exception as e:
|
122 |
+
st.error(f"Error in damage detection: {str(e)}")
|
123 |
+
else:
|
124 |
+
st.write("No damage detection performed for this fruit or quality level.")
|
125 |
|
126 |
+
if __name__ == "__main__":
|
127 |
+
main()
|
128 |
|
129 |
|
130 |
|
131 |
|
132 |
+
# import streamlit as st
|
133 |
+
# import numpy as np
|
134 |
+
# import cv2
|
135 |
+
# import warnings
|
136 |
+
# import os
|
137 |
+
# from pathlib import Path
|
138 |
+
# from PIL import Image
|
139 |
+
# import tensorflow as tf
|
140 |
+
# from tensorflow.keras.models import load_model
|
141 |
+
# from tensorflow.keras.preprocessing import image
|
142 |
+
# from detectron2.engine import DefaultPredictor
|
143 |
+
# from detectron2.config import get_cfg
|
144 |
+
# from detectron2.utils.visualizer import Visualizer
|
145 |
+
# from detectron2.data import MetadataCatalog
|
146 |
|
147 |
+
# # Suppress warnings
|
148 |
+
# warnings.filterwarnings("ignore")
|
149 |
+
# tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)
|
150 |
|
151 |
+
# # Configuration
|
152 |
+
# MODEL_CONFIG = {
|
153 |
+
# 'name_model': 'name_model_inception.h5',
|
154 |
+
# 'quality_model': 'type_model_inception.h5',
|
155 |
+
# 'input_size': (224, 224),
|
156 |
+
# 'score_threshold': 0.5
|
157 |
+
# }
|
158 |
|
159 |
+
# LABEL_MAPS = {
|
160 |
+
# 'name': {
|
161 |
+
# 0: "Banana", 1: "Cucumber", 2: "Grape", 3: "Kaki", 4: "Papaya",
|
162 |
+
# 5: "Peach", 6: "Pear", 7: "Peeper", 8: "Strawberry", 9: "Watermelon",
|
163 |
+
# 10: "tomato"
|
164 |
+
# },
|
165 |
+
# 'quality': {0: "Good", 1: "Mild", 2: "Rotten"}
|
166 |
+
# }
|
167 |
|
168 |
+
# @st.cache_resource
|
169 |
+
# def load_classification_models():
|
170 |
+
# """Load and cache the classification models."""
|
171 |
+
# try:
|
172 |
+
# model_name = load_model(MODEL_CONFIG['name_model'])
|
173 |
+
# model_quality = load_model(MODEL_CONFIG['quality_model'])
|
174 |
+
# return model_name, model_quality
|
175 |
+
# except Exception as e:
|
176 |
+
# st.error(f"Error loading classification models: {str(e)}")
|
177 |
+
# return None, None
|
178 |
|
179 |
+
# @st.cache_resource
|
180 |
+
# def load_detectron_model(fruit_name: str):
|
181 |
+
# """Load and cache the Detectron2 model for damage detection."""
|
182 |
+
# try:
|
183 |
+
# cfg = get_cfg()
|
184 |
+
# config_path = Path(f"{fruit_name.lower()}_config.yaml")
|
185 |
+
# model_path = Path(f"{fruit_name}_model.pth")
|
186 |
|
187 |
+
# if not config_path.exists() or not model_path.exists():
|
188 |
+
# raise FileNotFoundError(f"Model files not found for {fruit_name}")
|
189 |
|
190 |
+
# cfg.merge_from_file(str(config_path))
|
191 |
+
# cfg.MODEL.WEIGHTS = str(model_path)
|
192 |
+
# cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = MODEL_CONFIG['score_threshold']
|
193 |
+
# cfg.MODEL.DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
|
194 |
|
195 |
+
# return DefaultPredictor(cfg), cfg
|
196 |
+
# except Exception as e:
|
197 |
+
# st.error(f"Error loading Detectron2 model: {str(e)}")
|
198 |
+
# return None, None
|
199 |
+
|
200 |
+
# def preprocess_image(img: np.ndarray) -> tuple:
|
201 |
+
# """Preprocess the input image for model prediction."""
|
202 |
+
# try:
|
203 |
+
# # Convert to PIL Image if necessary
|
204 |
+
# if isinstance(img, np.ndarray):
|
205 |
+
# img = Image.fromarray(img.astype('uint8'), 'RGB')
|
206 |
|
207 |
+
# # Resize and prepare for model input
|
208 |
+
# img_resized = img.resize(MODEL_CONFIG['input_size'])
|
209 |
+
# img_array = image.img_to_array(img_resized)
|
210 |
+
# img_expanded = np.expand_dims(img_array, axis=0)
|
211 |
+
# img_normalized = img_expanded / 255.0
|
212 |
|
213 |
+
# return img_normalized, img_resized
|
214 |
+
# except Exception as e:
|
215 |
+
# st.error(f"Error preprocessing image: {str(e)}")
|
216 |
+
# return None, None
|
217 |
+
|
218 |
+
# def predict_fruit(img: np.ndarray) -> tuple:
|
219 |
+
# """Predict fruit type and quality."""
|
220 |
+
# model_name, model_quality = load_classification_models()
|
221 |
+
# if model_name is None or model_quality is None:
|
222 |
+
# return None, None, None
|
223 |
|
224 |
+
# img_normalized, img_resized = preprocess_image(img)
|
225 |
+
# if img_normalized is None:
|
226 |
+
# return None, None, None
|
227 |
|
228 |
+
# try:
|
229 |
+
# # Make predictions
|
230 |
+
# pred_name = model_name.predict(img_normalized)
|
231 |
+
# pred_quality = model_quality.predict(img_normalized)
|
232 |
|
233 |
+
# # Get predicted labels
|
234 |
+
# predicted_name = LABEL_MAPS['name'][np.argmax(pred_name, axis=1)[0]]
|
235 |
+
# predicted_quality = LABEL_MAPS['quality'][np.argmax(pred_quality, axis=1)[0]]
|
236 |
|
237 |
+
# return predicted_name, predicted_quality, img_resized
|
238 |
+
# except Exception as e:
|
239 |
+
# st.error(f"Error making predictions: {str(e)}")
|
240 |
+
# return None, None, None
|
241 |
+
|
242 |
+
# def detect_damage(img: Image, fruit_name: str) -> np.ndarray:
|
243 |
+
# """Detect and visualize damage in the fruit image."""
|
244 |
+
# predictor, cfg = load_detectron_model(fruit_name)
|
245 |
+
# if predictor is None or cfg is None:
|
246 |
+
# return None
|
247 |
|
248 |
+
# try:
|
249 |
+
# outputs = predictor(cv2.cvtColor(np.array(img), cv2.COLOR_RGB2BGR))
|
250 |
+
# v = Visualizer(np.array(img), MetadataCatalog.get(cfg.DATASETS.TRAIN[0]), scale=0.8)
|
251 |
+
# out = v.draw_instance_predictions(outputs["instances"].to("cpu"))
|
252 |
+
# return out.get_image()
|
253 |
+
# except Exception as e:
|
254 |
+
# st.error(f"Error in damage detection: {str(e)}")
|
255 |
+
# return None
|
256 |
|
257 |
+
# def main():
|
258 |
+
# st.set_page_config(page_title="Fruit Quality Analysis", layout="wide")
|
259 |
|
260 |
+
# st.title("Automated Fruits Monitoring System")
|
261 |
+
# st.write("Upload an image of a fruit to detect its type, quality, and potential damage.")
|
262 |
|
263 |
+
# uploaded_file = st.file_uploader("Choose a fruit image...", type=["jpg", "jpeg", "png"])
|
264 |
|
265 |
+
# if uploaded_file is not None:
|
266 |
+
# # Create two columns for layout
|
267 |
+
# col1, col2 = st.columns(2)
|
268 |
|
269 |
+
# # Display uploaded image
|
270 |
+
# image = Image.open(uploaded_file)
|
271 |
+
# col1.image(image, caption="Uploaded Image", use_column_width=True)
|
272 |
|
273 |
+
# if col1.button("Analyze"):
|
274 |
+
# with st.spinner("Analyzing image..."):
|
275 |
+
# predicted_name, predicted_quality, img_resized = predict_fruit(np.array(image))
|
276 |
|
277 |
+
# if predicted_name and predicted_quality:
|
278 |
+
# # Display results
|
279 |
+
# col2.markdown("### Analysis Results")
|
280 |
+
# col2.markdown(f"**Fruit Type:** {predicted_name}")
|
281 |
+
# col2.markdown(f"**Quality:** {predicted_quality}")
|
282 |
|
283 |
+
# # Check if damage detection is needed
|
284 |
+
# if (predicted_name.lower() in LABEL_MAPS['name'].values() and
|
285 |
+
# predicted_quality in ["Mild", "Rotten"]):
|
286 |
|
287 |
+
# col2.markdown("### Damage Detection")
|
288 |
+
# damage_image = detect_damage(img_resized, predicted_name)
|
289 |
|
290 |
+
# if damage_image is not None:
|
291 |
+
# col2.image(damage_image, caption="Detected Damage Regions",
|
292 |
+
# use_column_width=True)
|
293 |
|
294 |
+
# # Add download button for the damage detection result
|
295 |
+
# col2.download_button(
|
296 |
+
# label="Download Analysis Result",
|
297 |
+
# data=cv2.imencode('.png', damage_image)[1].tobytes(),
|
298 |
+
# file_name=f"{predicted_name}_damage_analysis.png",
|
299 |
+
# mime="image/png"
|
300 |
+
# )
|
301 |
|
302 |
+
# if __name__ == "__main__":
|
303 |
+
# main()
|