MuhammmadRizwanRizwan commited on
Commit
bd891ff
·
verified ·
1 Parent(s): a411d5f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +149 -13
app.py CHANGED
@@ -338,6 +338,133 @@
338
 
339
 
340
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
341
 
342
 
343
 
@@ -390,24 +517,26 @@ def load_models():
390
 
391
  model_name, model_quality = load_models()
392
 
393
- # Setup Detectron2 configuration for watermelon
394
  @st.cache_resource
395
- def load_detectron_model():
396
  cfg = get_cfg()
397
- cfg.merge_from_file("watermelon.yaml")
398
- cfg.MODEL.WEIGHTS = "Watermelon_model.pth"
 
 
 
 
399
  cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = 0.5
400
  cfg.MODEL.DEVICE = 'cpu' # Use CPU for inference
401
  predictor = DefaultPredictor(cfg)
402
  return predictor, cfg
403
 
404
- predictor, cfg = load_detectron_model()
405
-
406
  # Streamlit app title
407
- st.title("Watermelon Quality and Damage Detection")
408
 
409
  # Upload image
410
- uploaded_file = st.file_uploader("Choose a watermelon image...", type=["jpg", "jpeg", "png"])
411
 
412
  if uploaded_file is not None:
413
  try:
@@ -420,15 +549,15 @@ if uploaded_file is not None:
420
  # Display uploaded image
421
  st.image(uploaded_file, caption="Uploaded Image", use_column_width=True)
422
 
423
- # Predict watermelon name
424
  pred_name = model_name.predict(img_array)
425
- predicted_name = 'Watermelon'
426
 
427
- # Predict watermelon quality
428
  pred_quality = model_quality.predict(img_array)
429
  predicted_class_quality = np.argmax(pred_quality, axis=1)
430
 
431
- # Define labels for watermelon quality
432
  label_map_quality = {
433
  0: "Good",
434
  1: "Mild",
@@ -441,13 +570,19 @@ if uploaded_file is not None:
441
  st.write(f"Fruit Type Detection: {predicted_name}")
442
  st.write(f"Fruit Quality Classification: {predicted_quality}")
443
 
444
- # If the quality is 'Mild' or 'Rotten', pass the image to the mask detection model
445
  if predicted_quality in ["Mild", "Rotten"]:
446
  st.write("Passing the image to the mask detection model for damage detection...")
447
 
448
  # Load the image again for the mask detection (Detectron2 requires the original image)
449
  im = cv2.imdecode(np.fromstring(uploaded_file.read(), np.uint8), 1)
450
 
 
 
 
 
 
 
451
  # Run prediction on the image
452
  outputs = predictor(im)
453
 
@@ -460,3 +595,4 @@ if uploaded_file is not None:
460
 
461
  except Exception as e:
462
  st.error(f"An error occurred during processing: {str(e)}")
 
 
338
 
339
 
340
 
341
+ # ///////////////////////////////////Working
342
+
343
+
344
+ # import streamlit as st
345
+ # import numpy as np
346
+ # import cv2
347
+ # import warnings
348
+ # import os
349
+
350
+ # # Suppress warnings
351
+ # warnings.filterwarnings("ignore", category=FutureWarning)
352
+ # warnings.filterwarnings("ignore", category=UserWarning)
353
+
354
+ # # Try importing TensorFlow
355
+ # try:
356
+ # from tensorflow.keras.models import load_model
357
+ # from tensorflow.keras.preprocessing import image
358
+ # except ImportError:
359
+ # st.error("Failed to import TensorFlow. Please make sure it's installed correctly.")
360
+
361
+ # # Try importing PyTorch and Detectron2
362
+ # try:
363
+ # import torch
364
+ # import detectron2
365
+ # except ImportError:
366
+ # with st.spinner("Installing PyTorch and Detectron2..."):
367
+ # os.system("pip install torch torchvision")
368
+ # os.system("pip install 'git+https://github.com/facebookresearch/detectron2.git'")
369
+
370
+ # import torch
371
+ # import detectron2
372
+
373
+ # from detectron2.engine import DefaultPredictor
374
+ # from detectron2.config import get_cfg
375
+ # from detectron2.utils.visualizer import Visualizer
376
+ # from detectron2.data import MetadataCatalog
377
+
378
+ # # Load the trained models
379
+ # @st.cache_resource
380
+ # def load_models():
381
+ # try:
382
+ # model_path_name = 'name_model_inception.h5'
383
+ # model_path_quality = 'type_model_inception.h5'
384
+ # model_name = load_model(model_path_name)
385
+ # model_quality = load_model(model_path_quality)
386
+ # return model_name, model_quality
387
+ # except Exception as e:
388
+ # st.error(f"Failed to load models: {str(e)}")
389
+ # return None, None
390
+
391
+ # model_name, model_quality = load_models()
392
+
393
+ # # Setup Detectron2 configuration for watermelon
394
+ # @st.cache_resource
395
+ # def load_detectron_model():
396
+ # cfg = get_cfg()
397
+ # cfg.merge_from_file("watermelon.yaml")
398
+ # cfg.MODEL.WEIGHTS = "Watermelon_model.pth"
399
+ # cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = 0.5
400
+ # cfg.MODEL.DEVICE = 'cpu' # Use CPU for inference
401
+ # predictor = DefaultPredictor(cfg)
402
+ # return predictor, cfg
403
+
404
+ # predictor, cfg = load_detectron_model()
405
+
406
+ # # Streamlit app title
407
+ # st.title("Watermelon Quality and Damage Detection")
408
+
409
+ # # Upload image
410
+ # uploaded_file = st.file_uploader("Choose a watermelon image...", type=["jpg", "jpeg", "png"])
411
+
412
+ # if uploaded_file is not None:
413
+ # try:
414
+ # # Load the image
415
+ # img = image.load_img(uploaded_file, target_size=(224, 224))
416
+ # img_array = image.img_to_array(img)
417
+ # img_array = np.expand_dims(img_array, axis=0)
418
+ # img_array /= 255.0
419
+
420
+ # # Display uploaded image
421
+ # st.image(uploaded_file, caption="Uploaded Image", use_column_width=True)
422
+
423
+ # # Predict watermelon name
424
+ # pred_name = model_name.predict(img_array)
425
+ # predicted_name = 'Watermelon'
426
+
427
+ # # Predict watermelon quality
428
+ # pred_quality = model_quality.predict(img_array)
429
+ # predicted_class_quality = np.argmax(pred_quality, axis=1)
430
+
431
+ # # Define labels for watermelon quality
432
+ # label_map_quality = {
433
+ # 0: "Good",
434
+ # 1: "Mild",
435
+ # 2: "Rotten"
436
+ # }
437
+
438
+ # predicted_quality = label_map_quality[predicted_class_quality[0]]
439
+
440
+ # # Display predictions
441
+ # st.write(f"Fruit Type Detection: {predicted_name}")
442
+ # st.write(f"Fruit Quality Classification: {predicted_quality}")
443
+
444
+ # # If the quality is 'Mild' or 'Rotten', pass the image to the mask detection model
445
+ # if predicted_quality in ["Mild", "Rotten"]:
446
+ # st.write("Passing the image to the mask detection model for damage detection...")
447
+
448
+ # # Load the image again for the mask detection (Detectron2 requires the original image)
449
+ # im = cv2.imdecode(np.fromstring(uploaded_file.read(), np.uint8), 1)
450
+
451
+ # # Run prediction on the image
452
+ # outputs = predictor(im)
453
+
454
+ # # Visualize the predictions
455
+ # v = Visualizer(im[:, :, ::-1], MetadataCatalog.get(cfg.DATASETS.TRAIN[0]), scale=0.8)
456
+ # out = v.draw_instance_predictions(outputs["instances"].to("cpu"))
457
+
458
+ # # Display the output
459
+ # st.image(out.get_image()[:, :, ::-1], caption="Detected Damage", use_column_width=True)
460
+
461
+ # except Exception as e:
462
+ # st.error(f"An error occurred during processing: {str(e)}")
463
+
464
+
465
+
466
+
467
+
468
 
469
 
470
 
 
517
 
518
  model_name, model_quality = load_models()
519
 
520
+ # Setup Detectron2 configuration for watermelon and banana
521
  @st.cache_resource
522
+ def load_detectron_model(model_type):
523
  cfg = get_cfg()
524
+ if model_type == "watermelon":
525
+ cfg.merge_from_file("watermelon.yaml")
526
+ cfg.MODEL.WEIGHTS = "Watermelon_model.pth"
527
+ elif model_type == "tomato":
528
+ cfg.merge_from_file("tomato.yaml")
529
+ cfg.MODEL.WEIGHTS = "tomato_model.pth"
530
  cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = 0.5
531
  cfg.MODEL.DEVICE = 'cpu' # Use CPU for inference
532
  predictor = DefaultPredictor(cfg)
533
  return predictor, cfg
534
 
 
 
535
  # Streamlit app title
536
+ st.title("Fruit Quality and Damage Detection")
537
 
538
  # Upload image
539
+ uploaded_file = st.file_uploader("Choose a fruit image...", type=["jpg", "jpeg", "png"])
540
 
541
  if uploaded_file is not None:
542
  try:
 
549
  # Display uploaded image
550
  st.image(uploaded_file, caption="Uploaded Image", use_column_width=True)
551
 
552
+ # Predict fruit name
553
  pred_name = model_name.predict(img_array)
554
+ predicted_name = 'Banana' # Example: Modify based on the actual model's output
555
 
556
+ # Predict fruit quality
557
  pred_quality = model_quality.predict(img_array)
558
  predicted_class_quality = np.argmax(pred_quality, axis=1)
559
 
560
+ # Define labels for fruit quality
561
  label_map_quality = {
562
  0: "Good",
563
  1: "Mild",
 
570
  st.write(f"Fruit Type Detection: {predicted_name}")
571
  st.write(f"Fruit Quality Classification: {predicted_quality}")
572
 
573
+ # If the quality is 'Mild' or 'Rotten', pass the image to the corresponding mask detection model
574
  if predicted_quality in ["Mild", "Rotten"]:
575
  st.write("Passing the image to the mask detection model for damage detection...")
576
 
577
  # Load the image again for the mask detection (Detectron2 requires the original image)
578
  im = cv2.imdecode(np.fromstring(uploaded_file.read(), np.uint8), 1)
579
 
580
+ # Check if the predicted fruit is Banana or Watermelon and load the correct model
581
+ if predicted_name == "tomato":
582
+ predictor, cfg = load_detectron_model("tomato")
583
+ elif predicted_name == "Watermelon":
584
+ predictor, cfg = load_detectron_model("watermelon")
585
+
586
  # Run prediction on the image
587
  outputs = predictor(im)
588
 
 
595
 
596
  except Exception as e:
597
  st.error(f"An error occurred during processing: {str(e)}")
598
+