Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
@@ -1,40 +1,142 @@
|
|
1 |
|
2 |
-
import streamlit as st
|
3 |
-
import numpy as np
|
4 |
-
import cv2
|
5 |
-
import warnings
|
6 |
-
import os
|
7 |
|
8 |
-
# Suppress warnings
|
9 |
-
warnings.filterwarnings("ignore", category=FutureWarning)
|
10 |
-
warnings.filterwarnings("ignore", category=UserWarning)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
11 |
|
12 |
-
#
|
13 |
-
|
14 |
-
|
15 |
-
|
16 |
-
|
17 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
18 |
|
19 |
-
# Try importing PyTorch and Detectron2
|
20 |
-
try:
|
21 |
-
import torch
|
22 |
-
import detectron2
|
23 |
-
except ImportError:
|
24 |
-
with st.spinner("Installing PyTorch and Detectron2..."):
|
25 |
-
os.system("pip install torch torchvision")
|
26 |
-
os.system("pip install 'git+https://github.com/facebookresearch/detectron2.git'")
|
27 |
|
28 |
-
import torch
|
29 |
-
import detectron2
|
30 |
|
31 |
|
32 |
import streamlit as st
|
33 |
import numpy as np
|
34 |
import cv2
|
35 |
-
import
|
36 |
import os
|
|
|
37 |
from PIL import Image
|
|
|
38 |
from tensorflow.keras.models import load_model
|
39 |
from tensorflow.keras.preprocessing import image
|
40 |
from detectron2.engine import DefaultPredictor
|
@@ -43,85 +145,159 @@ from detectron2.utils.visualizer import Visualizer
|
|
43 |
from detectron2.data import MetadataCatalog
|
44 |
|
45 |
# Suppress warnings
|
46 |
-
import warnings
|
47 |
-
import tensorflow as tf
|
48 |
warnings.filterwarnings("ignore")
|
49 |
tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)
|
50 |
|
51 |
-
|
52 |
-
|
53 |
-
|
54 |
-
|
55 |
-
|
|
|
|
|
56 |
|
57 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
58 |
|
59 |
-
# Detectron2 setup
|
60 |
@st.cache_resource
|
61 |
-
def
|
62 |
-
|
63 |
-
|
64 |
-
|
65 |
-
|
66 |
-
|
67 |
-
|
68 |
-
|
69 |
-
|
70 |
-
return predictor, cfg
|
71 |
-
|
72 |
-
# Labels
|
73 |
-
label_map_name = {
|
74 |
-
0: "Banana", 1: "Cucumber", 2: "Grape", 3: "Kaki", 4: "Papaya",
|
75 |
-
5: "Peach", 6: "Pear", 7: "Peeper", 8: "Strawberry", 9: "Watermelon",
|
76 |
-
10: "tomato"
|
77 |
-
}
|
78 |
-
label_map_quality = {0: "Good", 1: "Mild", 2: "Rotten"}
|
79 |
|
80 |
-
|
81 |
-
|
82 |
-
|
83 |
-
|
84 |
-
|
85 |
-
|
86 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
87 |
|
88 |
-
|
89 |
-
|
90 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
91 |
|
92 |
-
|
93 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
94 |
|
95 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
96 |
|
97 |
def main():
|
98 |
-
st.
|
|
|
|
|
99 |
st.write("Upload an image of a fruit to detect its type, quality, and potential damage.")
|
100 |
-
|
101 |
uploaded_file = st.file_uploader("Choose a fruit image...", type=["jpg", "jpeg", "png"])
|
102 |
-
|
103 |
if uploaded_file is not None:
|
|
|
|
|
|
|
|
|
104 |
image = Image.open(uploaded_file)
|
105 |
-
|
106 |
-
|
107 |
-
if
|
108 |
-
|
109 |
-
|
110 |
-
|
111 |
-
|
112 |
-
|
113 |
-
|
114 |
-
|
115 |
-
|
116 |
-
|
117 |
-
|
118 |
-
|
119 |
-
|
120 |
-
|
121 |
-
|
122 |
-
|
123 |
-
|
124 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
125 |
|
126 |
if __name__ == "__main__":
|
127 |
main()
|
|
|
1 |
|
2 |
+
# import streamlit as st
|
3 |
+
# import numpy as np
|
4 |
+
# import cv2
|
5 |
+
# import warnings
|
6 |
+
# import os
|
7 |
|
8 |
+
# # Suppress warnings
|
9 |
+
# warnings.filterwarnings("ignore", category=FutureWarning)
|
10 |
+
# warnings.filterwarnings("ignore", category=UserWarning)
|
11 |
+
|
12 |
+
# # Try importing TensorFlow
|
13 |
+
# try:
|
14 |
+
# from tensorflow.keras.models import load_model
|
15 |
+
# from tensorflow.keras.preprocessing import image
|
16 |
+
# except ImportError:
|
17 |
+
# st.error("Failed to import TensorFlow. Please make sure it's installed correctly.")
|
18 |
+
|
19 |
+
# # Try importing PyTorch and Detectron2
|
20 |
+
# try:
|
21 |
+
# import torch
|
22 |
+
# import detectron2
|
23 |
+
# except ImportError:
|
24 |
+
# with st.spinner("Installing PyTorch and Detectron2..."):
|
25 |
+
# os.system("pip install torch torchvision")
|
26 |
+
# os.system("pip install 'git+https://github.com/facebookresearch/detectron2.git'")
|
27 |
+
|
28 |
+
# import torch
|
29 |
+
# import detectron2
|
30 |
+
|
31 |
+
|
32 |
+
# import streamlit as st
|
33 |
+
# import numpy as np
|
34 |
+
# import cv2
|
35 |
+
# import torch
|
36 |
+
# import os
|
37 |
+
# from PIL import Image
|
38 |
+
# from tensorflow.keras.models import load_model
|
39 |
+
# from tensorflow.keras.preprocessing import image
|
40 |
+
# from detectron2.engine import DefaultPredictor
|
41 |
+
# from detectron2.config import get_cfg
|
42 |
+
# from detectron2.utils.visualizer import Visualizer
|
43 |
+
# from detectron2.data import MetadataCatalog
|
44 |
+
|
45 |
+
# # Suppress warnings
|
46 |
+
# import warnings
|
47 |
+
# import tensorflow as tf
|
48 |
+
# warnings.filterwarnings("ignore")
|
49 |
+
# tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)
|
50 |
+
|
51 |
+
# @st.cache_resource
|
52 |
+
# def load_models():
|
53 |
+
# model_name = load_model('name_model_inception.h5')
|
54 |
+
# model_quality = load_model('type_model_inception.h5')
|
55 |
+
# return model_name, model_quality
|
56 |
+
|
57 |
+
# model_name, model_quality = load_models()
|
58 |
+
|
59 |
+
# # Detectron2 setup
|
60 |
+
# @st.cache_resource
|
61 |
+
# def load_detectron_model(fruit_name):
|
62 |
+
# cfg = get_cfg()
|
63 |
+
# config_path = os.path.join(f"{fruit_name.lower()}_config.yaml")
|
64 |
+
# cfg.merge_from_file(config_path)
|
65 |
+
# model_path = os.path.join(f"{fruit_name}_model.pth")
|
66 |
+
# cfg.MODEL.WEIGHTS = model_path
|
67 |
+
# cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = 0.5
|
68 |
+
# cfg.MODEL.DEVICE = 'cpu'
|
69 |
+
# predictor = DefaultPredictor(cfg)
|
70 |
+
# return predictor, cfg
|
71 |
+
|
72 |
+
# # Labels
|
73 |
+
# label_map_name = {
|
74 |
+
# 0: "Banana", 1: "Cucumber", 2: "Grape", 3: "Kaki", 4: "Papaya",
|
75 |
+
# 5: "Peach", 6: "Pear", 7: "Peeper", 8: "Strawberry", 9: "Watermelon",
|
76 |
+
# 10: "tomato"
|
77 |
+
# }
|
78 |
+
# label_map_quality = {0: "Good", 1: "Mild", 2: "Rotten"}
|
79 |
+
|
80 |
+
# def predict_fruit(img):
|
81 |
+
# # Preprocess image
|
82 |
+
# img = Image.fromarray(img.astype('uint8'), 'RGB')
|
83 |
+
# img = img.resize((224, 224))
|
84 |
+
# x = image.img_to_array(img)
|
85 |
+
# x = np.expand_dims(x, axis=0)
|
86 |
+
# x = x / 255.0
|
87 |
+
|
88 |
+
# # Predict
|
89 |
+
# pred_name = model_name.predict(x)
|
90 |
+
# pred_quality = model_quality.predict(x)
|
91 |
+
|
92 |
+
# predicted_name = label_map_name[np.argmax(pred_name, axis=1)[0]]
|
93 |
+
# predicted_quality = label_map_quality[np.argmax(pred_quality, axis=1)[0]]
|
94 |
|
95 |
+
# return predicted_name, predicted_quality, img
|
96 |
+
|
97 |
+
# def main():
|
98 |
+
# st.title("Automated Fruits Monitoring System")
|
99 |
+
# st.write("Upload an image of a fruit to detect its type, quality, and potential damage.")
|
100 |
+
|
101 |
+
# uploaded_file = st.file_uploader("Choose a fruit image...", type=["jpg", "jpeg", "png"])
|
102 |
+
|
103 |
+
# if uploaded_file is not None:
|
104 |
+
# image = Image.open(uploaded_file)
|
105 |
+
# st.image(image, caption="Uploaded Image", use_column_width=True)
|
106 |
+
|
107 |
+
# if st.button("Analyze"):
|
108 |
+
# predicted_name, predicted_quality, img = predict_fruit(np.array(image))
|
109 |
+
|
110 |
+
# st.write(f"Fruits Type Detection: {predicted_name}")
|
111 |
+
# st.write(f"Fruits Quality Classification: {predicted_quality}")
|
112 |
+
|
113 |
+
# if predicted_name.lower() in ["kaki", "tomato", "strawberry", "peeper", "pear", "peach", "papaya", "watermelon", "grape", "banana", "cucumber"] and predicted_quality in ["Mild", "Rotten"]:
|
114 |
+
# st.write("Segmentation of Defective Region:")
|
115 |
+
# try:
|
116 |
+
# predictor, cfg = load_detectron_model(predicted_name)
|
117 |
+
# outputs = predictor(cv2.cvtColor(np.array(img), cv2.COLOR_RGB2BGR))
|
118 |
+
# v = Visualizer(np.array(img), MetadataCatalog.get(cfg.DATASETS.TRAIN[0]), scale=0.8)
|
119 |
+
# out = v.draw_instance_predictions(outputs["instances"].to("cpu"))
|
120 |
+
# st.image(out.get_image(), caption="Damage Detection Result", use_column_width=True)
|
121 |
+
# except Exception as e:
|
122 |
+
# st.error(f"Error in damage detection: {str(e)}")
|
123 |
+
# else:
|
124 |
+
# st.write("No damage detection performed for this fruit or quality level.")
|
125 |
+
|
126 |
+
# if __name__ == "__main__":
|
127 |
+
# main()
|
128 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
129 |
|
|
|
|
|
130 |
|
131 |
|
132 |
import streamlit as st
|
133 |
import numpy as np
|
134 |
import cv2
|
135 |
+
import warnings
|
136 |
import os
|
137 |
+
from pathlib import Path
|
138 |
from PIL import Image
|
139 |
+
import tensorflow as tf
|
140 |
from tensorflow.keras.models import load_model
|
141 |
from tensorflow.keras.preprocessing import image
|
142 |
from detectron2.engine import DefaultPredictor
|
|
|
145 |
from detectron2.data import MetadataCatalog
|
146 |
|
147 |
# Suppress warnings
|
|
|
|
|
148 |
warnings.filterwarnings("ignore")
|
149 |
tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)
|
150 |
|
151 |
+
# Configuration
|
152 |
+
MODEL_CONFIG = {
|
153 |
+
'name_model': 'name_model_inception.h5',
|
154 |
+
'quality_model': 'type_model_inception.h5',
|
155 |
+
'input_size': (224, 224),
|
156 |
+
'score_threshold': 0.5
|
157 |
+
}
|
158 |
|
159 |
+
LABEL_MAPS = {
|
160 |
+
'name': {
|
161 |
+
0: "Banana", 1: "Cucumber", 2: "Grape", 3: "Kaki", 4: "Papaya",
|
162 |
+
5: "Peach", 6: "Pear", 7: "Peeper", 8: "Strawberry", 9: "Watermelon",
|
163 |
+
10: "tomato"
|
164 |
+
},
|
165 |
+
'quality': {0: "Good", 1: "Mild", 2: "Rotten"}
|
166 |
+
}
|
167 |
|
|
|
168 |
@st.cache_resource
|
169 |
+
def load_classification_models():
|
170 |
+
"""Load and cache the classification models."""
|
171 |
+
try:
|
172 |
+
model_name = load_model(MODEL_CONFIG['name_model'])
|
173 |
+
model_quality = load_model(MODEL_CONFIG['quality_model'])
|
174 |
+
return model_name, model_quality
|
175 |
+
except Exception as e:
|
176 |
+
st.error(f"Error loading classification models: {str(e)}")
|
177 |
+
return None, None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
178 |
|
179 |
+
@st.cache_resource
|
180 |
+
def load_detectron_model(fruit_name: str):
|
181 |
+
"""Load and cache the Detectron2 model for damage detection."""
|
182 |
+
try:
|
183 |
+
cfg = get_cfg()
|
184 |
+
config_path = Path(f"{fruit_name.lower()}_config.yaml")
|
185 |
+
model_path = Path(f"{fruit_name}_model.pth")
|
186 |
+
|
187 |
+
if not config_path.exists() or not model_path.exists():
|
188 |
+
raise FileNotFoundError(f"Model files not found for {fruit_name}")
|
189 |
+
|
190 |
+
cfg.merge_from_file(str(config_path))
|
191 |
+
cfg.MODEL.WEIGHTS = str(model_path)
|
192 |
+
cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = MODEL_CONFIG['score_threshold']
|
193 |
+
cfg.MODEL.DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
|
194 |
+
|
195 |
+
return DefaultPredictor(cfg), cfg
|
196 |
+
except Exception as e:
|
197 |
+
st.error(f"Error loading Detectron2 model: {str(e)}")
|
198 |
+
return None, None
|
199 |
|
200 |
+
def preprocess_image(img: np.ndarray) -> tuple:
|
201 |
+
"""Preprocess the input image for model prediction."""
|
202 |
+
try:
|
203 |
+
# Convert to PIL Image if necessary
|
204 |
+
if isinstance(img, np.ndarray):
|
205 |
+
img = Image.fromarray(img.astype('uint8'), 'RGB')
|
206 |
+
|
207 |
+
# Resize and prepare for model input
|
208 |
+
img_resized = img.resize(MODEL_CONFIG['input_size'])
|
209 |
+
img_array = image.img_to_array(img_resized)
|
210 |
+
img_expanded = np.expand_dims(img_array, axis=0)
|
211 |
+
img_normalized = img_expanded / 255.0
|
212 |
+
|
213 |
+
return img_normalized, img_resized
|
214 |
+
except Exception as e:
|
215 |
+
st.error(f"Error preprocessing image: {str(e)}")
|
216 |
+
return None, None
|
217 |
|
218 |
+
def predict_fruit(img: np.ndarray) -> tuple:
|
219 |
+
"""Predict fruit type and quality."""
|
220 |
+
model_name, model_quality = load_classification_models()
|
221 |
+
if model_name is None or model_quality is None:
|
222 |
+
return None, None, None
|
223 |
+
|
224 |
+
img_normalized, img_resized = preprocess_image(img)
|
225 |
+
if img_normalized is None:
|
226 |
+
return None, None, None
|
227 |
+
|
228 |
+
try:
|
229 |
+
# Make predictions
|
230 |
+
pred_name = model_name.predict(img_normalized)
|
231 |
+
pred_quality = model_quality.predict(img_normalized)
|
232 |
+
|
233 |
+
# Get predicted labels
|
234 |
+
predicted_name = LABEL_MAPS['name'][np.argmax(pred_name, axis=1)[0]]
|
235 |
+
predicted_quality = LABEL_MAPS['quality'][np.argmax(pred_quality, axis=1)[0]]
|
236 |
+
|
237 |
+
return predicted_name, predicted_quality, img_resized
|
238 |
+
except Exception as e:
|
239 |
+
st.error(f"Error making predictions: {str(e)}")
|
240 |
+
return None, None, None
|
241 |
|
242 |
+
def detect_damage(img: Image, fruit_name: str) -> np.ndarray:
|
243 |
+
"""Detect and visualize damage in the fruit image."""
|
244 |
+
predictor, cfg = load_detectron_model(fruit_name)
|
245 |
+
if predictor is None or cfg is None:
|
246 |
+
return None
|
247 |
+
|
248 |
+
try:
|
249 |
+
outputs = predictor(cv2.cvtColor(np.array(img), cv2.COLOR_RGB2BGR))
|
250 |
+
v = Visualizer(np.array(img), MetadataCatalog.get(cfg.DATASETS.TRAIN[0]), scale=0.8)
|
251 |
+
out = v.draw_instance_predictions(outputs["instances"].to("cpu"))
|
252 |
+
return out.get_image()
|
253 |
+
except Exception as e:
|
254 |
+
st.error(f"Error in damage detection: {str(e)}")
|
255 |
+
return None
|
256 |
|
257 |
def main():
|
258 |
+
st.set_page_config(page_title="Fruit Quality Analysis", layout="wide")
|
259 |
+
|
260 |
+
st.title("Automated Fruits Monitoring System")
|
261 |
st.write("Upload an image of a fruit to detect its type, quality, and potential damage.")
|
262 |
+
|
263 |
uploaded_file = st.file_uploader("Choose a fruit image...", type=["jpg", "jpeg", "png"])
|
264 |
+
|
265 |
if uploaded_file is not None:
|
266 |
+
# Create two columns for layout
|
267 |
+
col1, col2 = st.columns(2)
|
268 |
+
|
269 |
+
# Display uploaded image
|
270 |
image = Image.open(uploaded_file)
|
271 |
+
col1.image(image, caption="Uploaded Image", use_column_width=True)
|
272 |
+
|
273 |
+
if col1.button("Analyze"):
|
274 |
+
with st.spinner("Analyzing image..."):
|
275 |
+
predicted_name, predicted_quality, img_resized = predict_fruit(np.array(image))
|
276 |
+
|
277 |
+
if predicted_name and predicted_quality:
|
278 |
+
# Display results
|
279 |
+
col2.markdown("### Analysis Results")
|
280 |
+
col2.markdown(f"**Fruit Type:** {predicted_name}")
|
281 |
+
col2.markdown(f"**Quality:** {predicted_quality}")
|
282 |
+
|
283 |
+
# Check if damage detection is needed
|
284 |
+
if (predicted_name.lower() in LABEL_MAPS['name'].values() and
|
285 |
+
predicted_quality in ["Mild", "Rotten"]):
|
286 |
+
|
287 |
+
col2.markdown("### Damage Detection")
|
288 |
+
damage_image = detect_damage(img_resized, predicted_name)
|
289 |
+
|
290 |
+
if damage_image is not None:
|
291 |
+
col2.image(damage_image, caption="Detected Damage Regions",
|
292 |
+
use_column_width=True)
|
293 |
+
|
294 |
+
# Add download button for the damage detection result
|
295 |
+
col2.download_button(
|
296 |
+
label="Download Analysis Result",
|
297 |
+
data=cv2.imencode('.png', damage_image)[1].tobytes(),
|
298 |
+
file_name=f"{predicted_name}_damage_analysis.png",
|
299 |
+
mime="image/png"
|
300 |
+
)
|
301 |
|
302 |
if __name__ == "__main__":
|
303 |
main()
|