Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
@@ -483,8 +483,7 @@ from detectron2.data import MetadataCatalog
|
|
483 |
|
484 |
|
485 |
|
486 |
-
|
487 |
-
import gradio as gr
|
488 |
import numpy as np
|
489 |
import cv2
|
490 |
import torch
|
@@ -502,15 +501,20 @@ import tensorflow as tf
|
|
502 |
warnings.filterwarnings("ignore")
|
503 |
tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)
|
504 |
|
505 |
-
|
506 |
-
|
507 |
-
|
|
|
|
|
|
|
|
|
508 |
|
509 |
# Detectron2 setup
|
|
|
510 |
def load_detectron_model(fruit_name):
|
511 |
cfg = get_cfg()
|
512 |
-
cfg.merge_from_file(f"{fruit_name.lower()}.yaml")
|
513 |
-
cfg.MODEL.WEIGHTS = f"{fruit_name}_model.pth"
|
514 |
cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = 0.5
|
515 |
cfg.MODEL.DEVICE = 'cpu'
|
516 |
predictor = DefaultPredictor(cfg)
|
@@ -539,31 +543,35 @@ def predict_fruit(img):
|
|
539 |
predicted_name = label_map_name[np.argmax(pred_name, axis=1)[0]]
|
540 |
predicted_quality = label_map_quality[np.argmax(pred_quality, axis=1)[0]]
|
541 |
|
542 |
-
|
543 |
|
544 |
-
|
545 |
-
|
546 |
-
|
547 |
-
outputs = predictor(cv2.cvtColor(np.array(img), cv2.COLOR_RGB2BGR))
|
548 |
-
v = Visualizer(np.array(img), MetadataCatalog.get(cfg.DATASETS.TRAIN[0]), scale=0.8)
|
549 |
-
out = v.draw_instance_predictions(outputs["instances"].to("cpu"))
|
550 |
-
result_image = out.get_image()
|
551 |
-
else:
|
552 |
-
result_image = np.array(img)
|
553 |
|
554 |
-
|
555 |
|
556 |
-
|
557 |
-
|
558 |
-
|
559 |
-
outputs=[gr.Textbox(), gr.Image()],
|
560 |
-
title="Fruit Quality and Damage Detection",
|
561 |
-
description="Upload an image of a fruit to detect its type, quality, and potential damage."
|
562 |
-
)
|
563 |
|
564 |
-
|
|
|
565 |
|
|
|
|
|
566 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
567 |
|
|
|
|
|
568 |
|
569 |
|
|
|
483 |
|
484 |
|
485 |
|
486 |
+
import streamlit as st
|
|
|
487 |
import numpy as np
|
488 |
import cv2
|
489 |
import torch
|
|
|
501 |
warnings.filterwarnings("ignore")
|
502 |
tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)
|
503 |
|
504 |
+
@st.cache_resource
|
505 |
+
def load_models():
|
506 |
+
model_name = load_model('name_model_inception.h5')
|
507 |
+
model_quality = load_model('type_model_inception.h5')
|
508 |
+
return model_name, model_quality
|
509 |
+
|
510 |
+
model_name, model_quality = load_models()
|
511 |
|
512 |
# Detectron2 setup
|
513 |
+
@st.cache_resource
|
514 |
def load_detectron_model(fruit_name):
|
515 |
cfg = get_cfg()
|
516 |
+
cfg.merge_from_file(f"{fruit_name.lower()}_config.yaml")
|
517 |
+
cfg.MODEL.WEIGHTS = f"{fruit_name.lower()}_model.pth"
|
518 |
cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = 0.5
|
519 |
cfg.MODEL.DEVICE = 'cpu'
|
520 |
predictor = DefaultPredictor(cfg)
|
|
|
543 |
predicted_name = label_map_name[np.argmax(pred_name, axis=1)[0]]
|
544 |
predicted_quality = label_map_quality[np.argmax(pred_quality, axis=1)[0]]
|
545 |
|
546 |
+
return predicted_name, predicted_quality, img
|
547 |
|
548 |
+
def main():
|
549 |
+
st.title("Fruit Quality and Damage Detection")
|
550 |
+
st.write("Upload an image of a fruit to detect its type, quality, and potential damage.")
|
|
|
|
|
|
|
|
|
|
|
|
|
551 |
|
552 |
+
uploaded_file = st.file_uploader("Choose a fruit image...", type=["jpg", "jpeg", "png"])
|
553 |
|
554 |
+
if uploaded_file is not None:
|
555 |
+
image = Image.open(uploaded_file)
|
556 |
+
st.image(image, caption="Uploaded Image", use_column_width=True)
|
|
|
|
|
|
|
|
|
557 |
|
558 |
+
if st.button("Analyze"):
|
559 |
+
predicted_name, predicted_quality, img = predict_fruit(np.array(image))
|
560 |
|
561 |
+
st.write(f"Fruit Type: {predicted_name}")
|
562 |
+
st.write(f"Fruit Quality: {predicted_quality}")
|
563 |
|
564 |
+
if predicted_name.lower() in ["kaki", "tomato", "strawberry", "pepper", "pear", "peach", "papaya", "watermelon", "grape", "banana", "cucumber"] and predicted_quality in ["Mild", "Rotten"]:
|
565 |
+
st.write("Detecting damage...")
|
566 |
+
predictor, cfg = load_detectron_model(predicted_name)
|
567 |
+
outputs = predictor(cv2.cvtColor(np.array(img), cv2.COLOR_RGB2BGR))
|
568 |
+
v = Visualizer(np.array(img), MetadataCatalog.get(cfg.DATASETS.TRAIN[0]), scale=0.8)
|
569 |
+
out = v.draw_instance_predictions(outputs["instances"].to("cpu"))
|
570 |
+
st.image(out.get_image(), caption="Damage Detection Result", use_column_width=True)
|
571 |
+
else:
|
572 |
+
st.write("No damage detection performed for this fruit or quality level.")
|
573 |
|
574 |
+
if __name__ == "__main__":
|
575 |
+
main()
|
576 |
|
577 |
|