MuhammmadRizwanRizwan commited on
Commit
8773d3c
·
verified ·
1 Parent(s): c8db654

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +151 -44
app.py CHANGED
@@ -16,56 +16,163 @@
16
 
17
 
18
 
19
- import streamlit as st
20
- from tensorflow.keras.models import load_model
21
- from tensorflow.keras.preprocessing import image
22
- import numpy as np
23
- from PIL import Image
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
24
 
25
- # Load the pre-trained models
26
- @st.cache_resource
27
- def load_models():
28
- model1 = load_model('name_model_inception.h5') # Update with your Hugging Face model path
29
- model2 = load_model('type_model_inception.h5') # Update with your Hugging Face model path
30
- return model1, model2
31
 
32
- model1, model2 = load_models()
 
 
33
 
34
- # Label mappings
35
- label_map1 = {
36
- 0: "Banana", 1: "Cucumber", 2: "Grape", 3: "Kaki", 4: "Papaya",
37
- 5: "Peach", 6: "Pear", 7: "Pepper", 8: "Strawberry", 9: "Watermelon", 10: "Tomato"
38
- }
39
 
40
- label_map2 = {
41
- 0: "Good", 1: "Mild", 2: "Rotten"
42
- }
43
 
44
- # Streamlit app layout
45
- st.title("Fruit Classifier")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
46
 
47
  # Upload image
48
- uploaded_file = st.file_uploader("Choose an image of a fruit", type=["jpg", "jpeg", "png"])
49
 
50
  if uploaded_file is not None:
51
- # Display the uploaded image
52
- img = Image.open(uploaded_file)
53
- st.image(img, caption="Uploaded Image", use_column_width=True)
54
-
55
- # Preprocess the image
56
- img = img.resize((224, 224)) # Resize image to match the model input
57
- img_array = image.img_to_array(img)
58
- img_array = np.expand_dims(img_array, axis=0)
59
- img_array = img_array / 255.0 # Normalize the image
60
-
61
- # Make predictions
62
- pred1 = model1.predict(img_array)
63
- pred2 = model2.predict(img_array)
64
-
65
- predicted_class1 = np.argmax(pred1, axis=1)
66
- predicted_class2 = np.argmax(pred2, axis=1)
67
-
68
- # Display results
69
- st.write(f"**Type Detection**: {label_map1[predicted_class1[0]]}")
70
- st.write(f"**Condition Detection**: {label_map2[predicted_class2[0]]}")
71
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
16
 
17
 
18
 
19
+ # import streamlit as st
20
+ # from tensorflow.keras.models import load_model
21
+ # from tensorflow.keras.preprocessing import image
22
+ # import numpy as np
23
+ # from PIL import Image
24
+
25
+ # # Load the pre-trained models
26
+ # @st.cache_resource
27
+ # def load_models():
28
+ # model1 = load_model('name_model_inception.h5') # Update with your Hugging Face model path
29
+ # model2 = load_model('type_model_inception.h5') # Update with your Hugging Face model path
30
+ # return model1, model2
31
+
32
+ # model1, model2 = load_models()
33
+
34
+ # # Label mappings
35
+ # label_map1 = {
36
+ # 0: "Banana", 1: "Cucumber", 2: "Grape", 3: "Kaki", 4: "Papaya",
37
+ # 5: "Peach", 6: "Pear", 7: "Pepper", 8: "Strawberry", 9: "Watermelon", 10: "Tomato"
38
+ # }
39
+
40
+ # label_map2 = {
41
+ # 0: "Good", 1: "Mild", 2: "Rotten"
42
+ # }
43
+
44
+ # # Streamlit app layout
45
+ # st.title("Fruit Classifier")
46
+
47
+ # # Upload image
48
+ # uploaded_file = st.file_uploader("Choose an image of a fruit", type=["jpg", "jpeg", "png"])
49
+
50
+ # if uploaded_file is not None:
51
+ # # Display the uploaded image
52
+ # img = Image.open(uploaded_file)
53
+ # st.image(img, caption="Uploaded Image", use_column_width=True)
54
+
55
+ # # Preprocess the image
56
+ # img = img.resize((224, 224)) # Resize image to match the model input
57
+ # img_array = image.img_to_array(img)
58
+ # img_array = np.expand_dims(img_array, axis=0)
59
+ # img_array = img_array / 255.0 # Normalize the image
60
+
61
+ # # Make predictions
62
+ # pred1 = model1.predict(img_array)
63
+ # pred2 = model2.predict(img_array)
64
 
65
+ # predicted_class1 = np.argmax(pred1, axis=1)
66
+ # predicted_class2 = np.argmax(pred2, axis=1)
 
 
 
 
67
 
68
+ # # Display results
69
+ # st.write(f"**Type Detection**: {label_map1[predicted_class1[0]]}")
70
+ # st.write(f"**Condition Detection**: {label_map2[predicted_class2[0]]}")
71
 
 
 
 
 
 
72
 
 
 
 
73
 
74
+ import streamlit as st
75
+ import numpy as np
76
+ import cv2
77
+ import warnings
78
+
79
+ # Suppress warnings
80
+ warnings.filterwarnings("ignore", category=FutureWarning)
81
+ warnings.filterwarnings("ignore", category=UserWarning)
82
+
83
+ # Try importing TensorFlow
84
+ try:
85
+ from tensorflow.keras.models import load_model
86
+ from tensorflow.keras.preprocessing import image
87
+ except ImportError:
88
+ st.error("Failed to import TensorFlow. Please make sure it's installed correctly.")
89
+
90
+ # Try importing PyTorch and Detectron2
91
+ try:
92
+ import torch
93
+ from detectron2.engine import DefaultPredictor
94
+ from detectron2.config import get_cfg
95
+ from detectron2.utils.visualizer import Visualizer
96
+ from detectron2.data import MetadataCatalog
97
+ except ImportError:
98
+ st.error("Failed to import PyTorch or Detectron2. Please make sure they're installed correctly.")
99
+
100
+ # Load the trained models
101
+ try:
102
+ model_path_name = 'name_model_inception.h5'
103
+ model_path_quality = 'type_model_inception.h5'
104
+ detectron_config_path = 'watermelon.yaml'
105
+ detectron_weights_path = 'Watermelon_model.pth'
106
+
107
+ model_name = load_model(model_path_name)
108
+ model_quality = load_model(model_path_quality)
109
+ except Exception as e:
110
+ st.error(f"Failed to load models: {str(e)}")
111
+
112
+ # Streamlit app title
113
+ st.title("Watermelon Quality and Damage Detection")
114
 
115
  # Upload image
116
+ uploaded_file = st.file_uploader("Choose a watermelon image...", type=["jpg", "jpeg", "png"])
117
 
118
  if uploaded_file is not None:
119
+ try:
120
+ # Load the image
121
+ img = image.load_img(uploaded_file, target_size=(224, 224))
122
+ img_array = image.img_to_array(img)
123
+ img_array = np.expand_dims(img_array, axis=0)
124
+ img_array /= 255.0
125
+
126
+ # Display uploaded image
127
+ st.image(uploaded_file, caption="Uploaded Image", use_column_width=True)
128
+
129
+ # Predict watermelon name
130
+ pred_name = model_name.predict(img_array)
131
+ predicted_name = 'Watermelon'
132
+
133
+ # Predict watermelon quality
134
+ pred_quality = model_quality.predict(img_array)
135
+ predicted_class_quality = np.argmax(pred_quality, axis=1)
136
+
137
+ # Define labels for watermelon quality
138
+ label_map_quality = {
139
+ 0: "Good",
140
+ 1: "Mild",
141
+ 2: "Rotten"
142
+ }
143
+
144
+ predicted_quality = label_map_quality[predicted_class_quality[0]]
145
+
146
+ # Display predictions
147
+ st.write(f"Fruit Type Detection: {predicted_name}")
148
+ st.write(f"Fruit Quality Classification: {predicted_quality}")
149
+
150
+ # If the quality is 'Mild' or 'Rotten', pass the image to the mask detection model
151
+ if predicted_quality in ["Mild", "Rotten"]:
152
+ st.write("Passing the image to the mask detection model for damage detection...")
153
+
154
+ # Load the image again for the mask detection (Detectron2 requires the original image)
155
+ im = cv2.imdecode(np.fromstring(uploaded_file.read(), np.uint8), 1)
156
+
157
+ # Setup Detectron2 configuration for watermelon
158
+ cfg = get_cfg()
159
+ cfg.merge_from_file(detectron_config_path)
160
+ cfg.MODEL.WEIGHTS = detectron_weights_path
161
+ cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = 0.5
162
+ cfg.MODEL.DEVICE = 'cpu' # Use CPU for inference
163
+
164
+ predictor = DefaultPredictor(cfg)
165
+ predictor.model.load_state_dict(torch.load(detectron_weights_path, map_location=torch.device('cpu')))
166
+
167
+ # Run prediction on the image
168
+ outputs = predictor(im)
169
+
170
+ # Visualize the predictions
171
+ v = Visualizer(im[:, :, ::-1], MetadataCatalog.get(cfg.DATASETS.TRAIN[0]), scale=0.8)
172
+ out = v.draw_instance_predictions(outputs["instances"].to("cpu"))
173
+
174
+ # Display the output
175
+ st.image(out.get_image()[:, :, ::-1], caption="Detected Damage", use_column_width=True)
176
+
177
+ except Exception as e:
178
+ st.error(f"An error occurred during processing: {str(e)}")