MuhammmadRizwanRizwan commited on
Commit
1523d77
·
verified ·
1 Parent(s): 59d62bb

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +34 -26
app.py CHANGED
@@ -483,8 +483,7 @@ from detectron2.data import MetadataCatalog
483
 
484
 
485
 
486
-
487
- import gradio as gr
488
  import numpy as np
489
  import cv2
490
  import torch
@@ -502,15 +501,20 @@ import tensorflow as tf
502
  warnings.filterwarnings("ignore")
503
  tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)
504
 
505
- # Load models
506
- model_name = load_model('name_model_inception.h5')
507
- model_quality = load_model('type_model_inception.h5')
 
 
 
 
508
 
509
  # Detectron2 setup
 
510
  def load_detectron_model(fruit_name):
511
  cfg = get_cfg()
512
- cfg.merge_from_file(f"{fruit_name.lower()}.yaml")
513
- cfg.MODEL.WEIGHTS = f"{fruit_name}_model.pth"
514
  cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = 0.5
515
  cfg.MODEL.DEVICE = 'cpu'
516
  predictor = DefaultPredictor(cfg)
@@ -539,31 +543,35 @@ def predict_fruit(img):
539
  predicted_name = label_map_name[np.argmax(pred_name, axis=1)[0]]
540
  predicted_quality = label_map_quality[np.argmax(pred_quality, axis=1)[0]]
541
 
542
- result = f"Fruit Type: {predicted_name}\nFruit Quality: {predicted_quality}"
543
 
544
- # Damage detection for specific fruits
545
- if predicted_name.lower() in ["kaki", "tomato", "strawberry", "pepper", "pear", "peach", "papaya", "watermelon", "grape", "banana", "cucumber"] and predicted_quality in ["Mild", "Rotten"]:
546
- predictor, cfg = load_detectron_model(predicted_name)
547
- outputs = predictor(cv2.cvtColor(np.array(img), cv2.COLOR_RGB2BGR))
548
- v = Visualizer(np.array(img), MetadataCatalog.get(cfg.DATASETS.TRAIN[0]), scale=0.8)
549
- out = v.draw_instance_predictions(outputs["instances"].to("cpu"))
550
- result_image = out.get_image()
551
- else:
552
- result_image = np.array(img)
553
 
554
- return result, result_image
555
 
556
- iface = gr.Interface(
557
- fn=predict_fruit,
558
- inputs=gr.Image(),
559
- outputs=[gr.Textbox(), gr.Image()],
560
- title="Fruit Quality and Damage Detection",
561
- description="Upload an image of a fruit to detect its type, quality, and potential damage."
562
- )
563
 
564
- iface.launch()
 
565
 
 
 
566
 
 
 
 
 
 
 
 
 
 
567
 
 
 
568
 
569
 
 
483
 
484
 
485
 
486
+ import streamlit as st
 
487
  import numpy as np
488
  import cv2
489
  import torch
 
501
  warnings.filterwarnings("ignore")
502
  tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)
503
 
504
+ @st.cache_resource
505
+ def load_models():
506
+ model_name = load_model('name_model_inception.h5')
507
+ model_quality = load_model('type_model_inception.h5')
508
+ return model_name, model_quality
509
+
510
+ model_name, model_quality = load_models()
511
 
512
  # Detectron2 setup
513
+ @st.cache_resource
514
  def load_detectron_model(fruit_name):
515
  cfg = get_cfg()
516
+ cfg.merge_from_file(f"{fruit_name.lower()}_config.yaml")
517
+ cfg.MODEL.WEIGHTS = f"{fruit_name.lower()}_model.pth"
518
  cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = 0.5
519
  cfg.MODEL.DEVICE = 'cpu'
520
  predictor = DefaultPredictor(cfg)
 
543
  predicted_name = label_map_name[np.argmax(pred_name, axis=1)[0]]
544
  predicted_quality = label_map_quality[np.argmax(pred_quality, axis=1)[0]]
545
 
546
+ return predicted_name, predicted_quality, img
547
 
548
+ def main():
549
+ st.title("Fruit Quality and Damage Detection")
550
+ st.write("Upload an image of a fruit to detect its type, quality, and potential damage.")
 
 
 
 
 
 
551
 
552
+ uploaded_file = st.file_uploader("Choose a fruit image...", type=["jpg", "jpeg", "png"])
553
 
554
+ if uploaded_file is not None:
555
+ image = Image.open(uploaded_file)
556
+ st.image(image, caption="Uploaded Image", use_column_width=True)
 
 
 
 
557
 
558
+ if st.button("Analyze"):
559
+ predicted_name, predicted_quality, img = predict_fruit(np.array(image))
560
 
561
+ st.write(f"Fruit Type: {predicted_name}")
562
+ st.write(f"Fruit Quality: {predicted_quality}")
563
 
564
+ if predicted_name.lower() in ["kaki", "tomato", "strawberry", "pepper", "pear", "peach", "papaya", "watermelon", "grape", "banana", "cucumber"] and predicted_quality in ["Mild", "Rotten"]:
565
+ st.write("Detecting damage...")
566
+ predictor, cfg = load_detectron_model(predicted_name)
567
+ outputs = predictor(cv2.cvtColor(np.array(img), cv2.COLOR_RGB2BGR))
568
+ v = Visualizer(np.array(img), MetadataCatalog.get(cfg.DATASETS.TRAIN[0]), scale=0.8)
569
+ out = v.draw_instance_predictions(outputs["instances"].to("cpu"))
570
+ st.image(out.get_image(), caption="Damage Detection Result", use_column_width=True)
571
+ else:
572
+ st.write("No damage detection performed for this fruit or quality level.")
573
 
574
+ if __name__ == "__main__":
575
+ main()
576
 
577