# Page header captured together with this file (kept for provenance):
# streamlit_app / app.py
# MuhammmadRizwanRizwan's picture
# Update app.py
# 1912d57 verified
# raw
# history blame
# 11.4 kB
# import streamlit as st
# import numpy as np
# import cv2
# import warnings
# import os
# # Suppress warnings
# warnings.filterwarnings("ignore", category=FutureWarning)
# warnings.filterwarnings("ignore", category=UserWarning)
# # Try importing TensorFlow
# try:
# from tensorflow.keras.models import load_model
# from tensorflow.keras.preprocessing import image
# except ImportError:
# st.error("Failed to import TensorFlow. Please make sure it's installed correctly.")
# # Try importing PyTorch and Detectron2
# try:
# import torch
# import detectron2
# except ImportError:
# with st.spinner("Installing PyTorch and Detectron2..."):
# os.system("pip install torch torchvision")
# os.system("pip install 'git+https://github.com/facebookresearch/detectron2.git'")
# import torch
# import detectron2
# import streamlit as st
# import numpy as np
# import cv2
# import torch
# import os
# from PIL import Image
# from tensorflow.keras.models import load_model
# from tensorflow.keras.preprocessing import image
# from detectron2.engine import DefaultPredictor
# from detectron2.config import get_cfg
# from detectron2.utils.visualizer import Visualizer
# from detectron2.data import MetadataCatalog
# # Suppress warnings
# import warnings
# import tensorflow as tf
# warnings.filterwarnings("ignore")
# tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)
# @st.cache_resource
# def load_models():
# model_name = load_model('name_model_inception.h5')
# model_quality = load_model('type_model_inception.h5')
# return model_name, model_quality
# model_name, model_quality = load_models()
# # Detectron2 setup
# @st.cache_resource
# def load_detectron_model(fruit_name):
# cfg = get_cfg()
# config_path = os.path.join(f"{fruit_name.lower()}_config.yaml")
# cfg.merge_from_file(config_path)
# model_path = os.path.join(f"{fruit_name}_model.pth")
# cfg.MODEL.WEIGHTS = model_path
# cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = 0.5
# cfg.MODEL.DEVICE = 'cpu'
# predictor = DefaultPredictor(cfg)
# return predictor, cfg
# # Labels
# label_map_name = {
# 0: "Banana", 1: "Cucumber", 2: "Grape", 3: "Kaki", 4: "Papaya",
# 5: "Peach", 6: "Pear", 7: "Peeper", 8: "Strawberry", 9: "Watermelon",
# 10: "tomato"
# }
# label_map_quality = {0: "Good", 1: "Mild", 2: "Rotten"}
# def predict_fruit(img):
# # Preprocess image
# img = Image.fromarray(img.astype('uint8'), 'RGB')
# img = img.resize((224, 224))
# x = image.img_to_array(img)
# x = np.expand_dims(x, axis=0)
# x = x / 255.0
# # Predict
# pred_name = model_name.predict(x)
# pred_quality = model_quality.predict(x)
# predicted_name = label_map_name[np.argmax(pred_name, axis=1)[0]]
# predicted_quality = label_map_quality[np.argmax(pred_quality, axis=1)[0]]
# return predicted_name, predicted_quality, img
# def main():
# st.title("Automated Fruits Monitoring System")
# st.write("Upload an image of a fruit to detect its type, quality, and potential damage.")
# uploaded_file = st.file_uploader("Choose a fruit image...", type=["jpg", "jpeg", "png"])
# if uploaded_file is not None:
# image = Image.open(uploaded_file)
# st.image(image, caption="Uploaded Image", use_column_width=True)
# if st.button("Analyze"):
# predicted_name, predicted_quality, img = predict_fruit(np.array(image))
# st.write(f"Fruits Type Detection: {predicted_name}")
# st.write(f"Fruits Quality Classification: {predicted_quality}")
# if predicted_name.lower() in ["kaki", "tomato", "strawberry", "peeper", "pear", "peach", "papaya", "watermelon", "grape", "banana", "cucumber"] and predicted_quality in ["Mild", "Rotten"]:
# st.write("Segmentation of Defective Region:")
# try:
# predictor, cfg = load_detectron_model(predicted_name)
# outputs = predictor(cv2.cvtColor(np.array(img), cv2.COLOR_RGB2BGR))
# v = Visualizer(np.array(img), MetadataCatalog.get(cfg.DATASETS.TRAIN[0]), scale=0.8)
# out = v.draw_instance_predictions(outputs["instances"].to("cpu"))
# st.image(out.get_image(), caption="Damage Detection Result", use_column_width=True)
# except Exception as e:
# st.error(f"Error in damage detection: {str(e)}")
# else:
# st.write("No damage detection performed for this fruit or quality level.")
# if __name__ == "__main__":
# main()
import os
import warnings
from pathlib import Path

import streamlit as st
import numpy as np
import cv2
from PIL import Image
import tensorflow as tf
from tensorflow.keras.models import load_model
from tensorflow.keras.preprocessing import image
# torch is needed for the CUDA availability check in load_detectron_model.
import torch
from detectron2.engine import DefaultPredictor
from detectron2.config import get_cfg
from detectron2.utils.visualizer import Visualizer
from detectron2.data import MetadataCatalog
# Quiet the console: ignore every Python warning category and raise the
# TensorFlow logger threshold to ERROR so INFO/WARNING records are dropped.
warnings.simplefilter("ignore")
tf.get_logger().setLevel(tf.compat.v1.logging.ERROR)
# Configuration
MODEL_CONFIG = {
'name_model': 'name_model_inception.h5',
'quality_model': 'type_model_inception.h5',
'input_size': (224, 224),
'score_threshold': 0.5
}
LABEL_MAPS = {
'name': {
0: "Banana", 1: "Cucumber", 2: "Grape", 3: "Kaki", 4: "Papaya",
5: "Peach", 6: "Pear", 7: "Peeper", 8: "Strawberry", 9: "Watermelon",
10: "tomato"
},
'quality': {0: "Good", 1: "Mild", 2: "Rotten"}
}
@st.cache_resource
def _load_classification_models_cached():
    """Load both Keras classifiers, raising on failure so nothing bad is cached."""
    model_name = load_model(MODEL_CONFIG['name_model'])
    model_quality = load_model(MODEL_CONFIG['quality_model'])
    return model_name, model_quality


def load_classification_models():
    """Return the cached fruit-type and quality classifiers.

    Returns:
        tuple: ``(model_name, model_quality)`` on success, or ``(None, None)``
        after reporting the error with ``st.error`` when loading fails.
    """
    try:
        # Exceptions propagate out of the cached helper, so only successful
        # loads are memoized; a failed load is retried on the next call
        # instead of returning a cached ``(None, None)`` forever.
        return _load_classification_models_cached()
    except Exception as e:
        st.error(f"Error loading classification models: {str(e)}")
        return None, None
@st.cache_resource
def _build_detectron_predictor(fruit_name: str):
    """Build the Detectron2 predictor for ``fruit_name``; raises on any failure."""
    # File naming as before: lowercase config name, label-cased weights name.
    # NOTE(review): presumably matches the files shipped with the app — confirm.
    config_path = Path(f"{fruit_name.lower()}_config.yaml")
    model_path = Path(f"{fruit_name}_model.pth")
    if not config_path.exists() or not model_path.exists():
        raise FileNotFoundError(f"Model files not found for {fruit_name}")
    cfg = get_cfg()
    cfg.merge_from_file(str(config_path))
    cfg.MODEL.WEIGHTS = str(model_path)
    cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = MODEL_CONFIG['score_threshold']
    cfg.MODEL.DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
    return DefaultPredictor(cfg), cfg


def load_detectron_model(fruit_name: str):
    """Return the cached Detectron2 damage-detection model for a fruit.

    Args:
        fruit_name: Predicted fruit label used to locate the config/weights.

    Returns:
        tuple: ``(predictor, cfg)`` on success, or ``(None, None)`` after
        reporting the error with ``st.error`` (missing files, bad config,
        load failure).
    """
    try:
        # Only successful builds are memoized (per fruit_name); failures
        # propagate out of the cached helper and are retried next time.
        return _build_detectron_predictor(fruit_name)
    except Exception as e:
        st.error(f"Error loading Detectron2 model: {str(e)}")
        return None, None
def preprocess_image(img: np.ndarray) -> tuple:
    """Turn an image into a normalised single-image batch for the classifiers.

    Args:
        img: Either a ``np.ndarray`` (2-D greyscale, or H x W x C with 2-4
            channels) or a PIL image in any mode.

    Returns:
        tuple: ``(batch, resized)``. ``batch`` is the resized RGB image as a
        float array with a leading batch axis of 1, scaled by 1/255.
        ``resized`` is the resized RGB PIL image. Returns ``(None, None)``
        after reporting the error with ``st.error`` on failure.
    """
    try:
        if isinstance(img, np.ndarray):
            # Let Pillow infer the mode from the array shape (L/LA/RGB/RGBA).
            # Forcing 'RGB' garbles 4-channel arrays (e.g. PNGs with alpha)
            # and fails on 2-D greyscale arrays.
            img = Image.fromarray(img.astype('uint8'))
        # The classifiers expect exactly three colour channels.
        img = img.convert('RGB')
        img_resized = img.resize(MODEL_CONFIG['input_size'])
        img_array = image.img_to_array(img_resized)
        img_expanded = np.expand_dims(img_array, axis=0)
        img_normalized = img_expanded / 255.0
        return img_normalized, img_resized
    except Exception as e:
        st.error(f"Error preprocessing image: {str(e)}")
        return None, None
def predict_fruit(img: np.ndarray) -> tuple:
    """Classify the fruit type and quality grade of an image.

    Args:
        img: Image array handed to ``preprocess_image``.

    Returns:
        tuple: ``(fruit_label, quality_label, resized_pil_image)``, or
        ``(None, None, None)`` if the models, preprocessing or prediction fail
        (each failure path reports via ``st.error``).
    """
    failure = (None, None, None)

    name_clf, quality_clf = load_classification_models()
    if name_clf is None or quality_clf is None:
        return failure

    batch, resized = preprocess_image(img)
    if batch is None:
        return failure

    try:
        name_scores = name_clf.predict(batch)
        quality_scores = quality_clf.predict(batch)
        # Highest-scoring class index of the single batch item -> label.
        fruit_label = LABEL_MAPS['name'][np.argmax(name_scores, axis=1)[0]]
        quality_label = LABEL_MAPS['quality'][np.argmax(quality_scores, axis=1)[0]]
    except Exception as e:
        st.error(f"Error making predictions: {str(e)}")
        return failure
    return fruit_label, quality_label, resized
def detect_damage(img: Image.Image, fruit_name: str) -> np.ndarray:
    """Run the fruit-specific Detectron2 model and draw its predictions.

    Args:
        img: RGB PIL image to analyse.
        fruit_name: Predicted fruit label; selects which model to load.

    Returns:
        The annotated image as a ``np.ndarray``, or ``None`` if the model
        could not be loaded or inference/visualisation failed (reported via
        ``st.error``).
    """
    predictor, cfg = load_detectron_model(fruit_name)
    if predictor is None or cfg is None:
        return None
    try:
        rgb = np.array(img)
        # The predictor is given a BGR copy; the Visualizer draws on RGB.
        outputs = predictor(cv2.cvtColor(rgb, cv2.COLOR_RGB2BGR))
        # A config without DATASETS.TRAIN would make ``TRAIN[0]`` raise
        # IndexError after a successful prediction. ``None`` lets the
        # Visualizer fall back to empty metadata instead.
        metadata = MetadataCatalog.get(cfg.DATASETS.TRAIN[0]) if cfg.DATASETS.TRAIN else None
        visualizer = Visualizer(rgb, metadata, scale=0.8)
        out = visualizer.draw_instance_predictions(outputs["instances"].to("cpu"))
        return out.get_image()
    except Exception as e:
        st.error(f"Error in damage detection: {str(e)}")
        return None
def main():
    """Render the fruit-analysis page.

    Flow:
        1. The user uploads an image.
        2. Pressing "Analyze" classifies its fruit type and quality.
        3. For "Mild"/"Rotten" results, the fruit-specific Detectron2 model
           runs; the annotated image is shown and offered as a PNG download.
    """
    st.set_page_config(page_title="Fruit Quality Analysis", layout="wide")
    st.title("Automated Fruits Monitoring System")
    st.write("Upload an image of a fruit to detect its type, quality, and potential damage.")
    uploaded_file = st.file_uploader("Choose a fruit image...", type=["jpg", "jpeg", "png"])
    if uploaded_file is None:
        return

    # Two columns: uploaded image on the left, results on the right.
    col1, col2 = st.columns(2)
    # Deliberately not named ``image``, which would read like the keras
    # ``image`` module imported at file level.
    uploaded_image = Image.open(uploaded_file)
    col1.image(uploaded_image, caption="Uploaded Image", use_column_width=True)
    if not col1.button("Analyze"):
        return

    # Uploaded PNGs may be RGBA, palette or greyscale; analysis needs RGB.
    rgb_image = uploaded_image.convert("RGB")
    # LABEL_MAPS mixes "Banana"-style and "tomato"-style capitalisation, so
    # compare lowercase against lowercase.
    damage_fruits = {label.lower() for label in LABEL_MAPS['name'].values()}

    with st.spinner("Analyzing image..."):
        predicted_name, predicted_quality, img_resized = predict_fruit(np.array(rgb_image))
        if not (predicted_name and predicted_quality):
            return  # predict_fruit's failure paths already reported the error

        col2.markdown("### Analysis Results")
        col2.markdown(f"**Fruit Type:** {predicted_name}")
        col2.markdown(f"**Quality:** {predicted_quality}")

        if predicted_name.lower() not in damage_fruits or predicted_quality not in ("Mild", "Rotten"):
            col2.write("No damage detection performed for this fruit or quality level.")
            return

        col2.markdown("### Damage Detection")
        damage_image = detect_damage(img_resized, predicted_name)
        if damage_image is None:
            return  # detect_damage already reported the error
        col2.image(damage_image, caption="Detected Damage Regions",
                   use_column_width=True)

        # The Visualizer output is RGB but cv2.imencode expects BGR; without
        # the conversion the downloaded PNG has red and blue swapped.
        encoded_ok, png = cv2.imencode('.png', cv2.cvtColor(damage_image, cv2.COLOR_RGB2BGR))
        if encoded_ok:
            col2.download_button(
                label="Download Analysis Result",
                data=png.tobytes(),
                file_name=f"{predicted_name}_damage_analysis.png",
                mime="image/png"
            )
        else:
            col2.error("Could not encode the analysis result for download.")


if __name__ == "__main__":
    main()