MuhammmadRizwanRizwan commited on
Commit
3734d84
·
verified ·
1 Parent(s): 83eb0c6

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +200 -100
app.py CHANGED
@@ -341,127 +341,227 @@
341
  # ///////////////////////////////////Working
342
 
343
 
344
- import streamlit as st
345
- import numpy as np
346
- import cv2
347
- import warnings
348
- import os
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
349
 
350
- # Suppress warnings
351
- warnings.filterwarnings("ignore", category=FutureWarning)
352
- warnings.filterwarnings("ignore", category=UserWarning)
353
-
354
- # Try importing TensorFlow
355
- try:
356
- from tensorflow.keras.models import load_model
357
- from tensorflow.keras.preprocessing import image
358
- except ImportError:
359
- st.error("Failed to import TensorFlow. Please make sure it's installed correctly.")
360
-
361
- # Try importing PyTorch and Detectron2
362
- try:
363
- import torch
364
- import detectron2
365
- except ImportError:
366
- with st.spinner("Installing PyTorch and Detectron2..."):
367
- os.system("pip install torch torchvision")
368
- os.system("pip install 'git+https://github.com/facebookresearch/detectron2.git'")
369
-
370
- import torch
371
- import detectron2
372
 
373
- from detectron2.engine import DefaultPredictor
374
- from detectron2.config import get_cfg
375
- from detectron2.utils.visualizer import Visualizer
376
- from detectron2.data import MetadataCatalog
377
 
378
- # Load the trained models
379
- @st.cache_resource
380
- def load_models():
381
- try:
382
- model_path_name = 'name_model_inception.h5'
383
- model_path_quality = 'type_model_inception.h5'
384
- model_name = load_model(model_path_name)
385
- model_quality = load_model(model_path_quality)
386
- return model_name, model_quality
387
- except Exception as e:
388
- st.error(f"Failed to load models: {str(e)}")
389
- return None, None
390
-
391
- model_name, model_quality = load_models()
392
-
393
- # Setup Detectron2 configuration for watermelon
394
- @st.cache_resource
395
- def load_detectron_model():
396
- cfg = get_cfg()
397
- cfg.merge_from_file("watermelon.yaml")
398
- cfg.MODEL.WEIGHTS = "Watermelon_model.pth"
399
- cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = 0.5
400
- cfg.MODEL.DEVICE = 'cpu' # Use CPU for inference
401
- predictor = DefaultPredictor(cfg)
402
- return predictor, cfg
403
 
404
- predictor, cfg = load_detectron_model()
405
 
406
- # Streamlit app title
407
- st.title("Watermelon Quality and Damage Detection")
408
 
409
- # Upload image
410
- uploaded_file = st.file_uploader("Choose a watermelon image...", type=["jpg", "jpeg", "png"])
411
 
412
- if uploaded_file is not None:
413
- try:
414
- # Load the image
415
- img = image.load_img(uploaded_file, target_size=(224, 224))
416
- img_array = image.img_to_array(img)
417
- img_array = np.expand_dims(img_array, axis=0)
418
- img_array /= 255.0
419
 
420
- # Display uploaded image
421
- st.image(uploaded_file, caption="Uploaded Image", use_column_width=True)
422
 
423
- # Predict watermelon name
424
- pred_name = model_name.predict(img_array)
425
- predicted_name = 'Watermelon'
426
 
427
- # Predict watermelon quality
428
- pred_quality = model_quality.predict(img_array)
429
- predicted_class_quality = np.argmax(pred_quality, axis=1)
430
 
431
- # Define labels for watermelon quality
432
- label_map_quality = {
433
- 0: "Good",
434
- 1: "Mild",
435
- 2: "Rotten"
436
- }
437
 
438
- predicted_quality = label_map_quality[predicted_class_quality[0]]
439
 
440
- # Display predictions
441
- st.write(f"Fruit Type Detection: {predicted_name}")
442
- st.write(f"Fruit Quality Classification: {predicted_quality}")
443
 
444
- # If the quality is 'Mild' or 'Rotten', pass the image to the mask detection model
445
- if predicted_quality in ["Mild", "Rotten"]:
446
- st.write("Passing the image to the mask detection model for damage detection...")
447
 
448
- # Load the image again for the mask detection (Detectron2 requires the original image)
449
- im = cv2.imdecode(np.fromstring(uploaded_file.read(), np.uint8), 1)
450
 
451
- # Run prediction on the image
452
- outputs = predictor(im)
453
 
454
- # Visualize the predictions
455
- v = Visualizer(im[:, :, ::-1], MetadataCatalog.get(cfg.DATASETS.TRAIN[0]), scale=0.8)
456
- out = v.draw_instance_predictions(outputs["instances"].to("cpu"))
457
 
458
- # Display the output
459
- st.image(out.get_image()[:, :, ::-1], caption="Detected Damage", use_column_width=True)
460
 
461
- except Exception as e:
462
- st.error(f"An error occurred during processing: {str(e)}")
463
 
464
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
465
 
466
 
467
 
 
341
  # ///////////////////////////////////Working
342
 
343
 
344
+ # import streamlit as st
345
+ # import numpy as np
346
+ # import cv2
347
+ # import warnings
348
+ # import os
349
+
350
+ # # Suppress warnings
351
+ # warnings.filterwarnings("ignore", category=FutureWarning)
352
+ # warnings.filterwarnings("ignore", category=UserWarning)
353
+
354
+ # # Try importing TensorFlow
355
+ # try:
356
+ # from tensorflow.keras.models import load_model
357
+ # from tensorflow.keras.preprocessing import image
358
+ # except ImportError:
359
+ # st.error("Failed to import TensorFlow. Please make sure it's installed correctly.")
360
+
361
+ # # Try importing PyTorch and Detectron2
362
+ # try:
363
+ # import torch
364
+ # import detectron2
365
+ # except ImportError:
366
+ # with st.spinner("Installing PyTorch and Detectron2..."):
367
+ # os.system("pip install torch torchvision")
368
+ # os.system("pip install 'git+https://github.com/facebookresearch/detectron2.git'")
369
+
370
+ # import torch
371
+ # import detectron2
372
+
373
+ # from detectron2.engine import DefaultPredictor
374
+ # from detectron2.config import get_cfg
375
+ # from detectron2.utils.visualizer import Visualizer
376
+ # from detectron2.data import MetadataCatalog
377
+
378
+ # # Load the trained models
379
+ # @st.cache_resource
380
+ # def load_models():
381
+ # try:
382
+ # model_path_name = 'name_model_inception.h5'
383
+ # model_path_quality = 'type_model_inception.h5'
384
+ # model_name = load_model(model_path_name)
385
+ # model_quality = load_model(model_path_quality)
386
+ # return model_name, model_quality
387
+ # except Exception as e:
388
+ # st.error(f"Failed to load models: {str(e)}")
389
+ # return None, None
390
+
391
+ # model_name, model_quality = load_models()
392
+
393
+ # # Setup Detectron2 configuration for watermelon
394
+ # @st.cache_resource
395
+ # def load_detectron_model():
396
+ # cfg = get_cfg()
397
+ # cfg.merge_from_file("watermelon.yaml")
398
+ # cfg.MODEL.WEIGHTS = "Watermelon_model.pth"
399
+ # cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = 0.5
400
+ # cfg.MODEL.DEVICE = 'cpu' # Use CPU for inference
401
+ # predictor = DefaultPredictor(cfg)
402
+ # return predictor, cfg
403
+
404
+ # predictor, cfg = load_detectron_model()
405
+
406
+ # # Streamlit app title
407
+ # st.title("Watermelon Quality and Damage Detection")
408
+
409
+ # # Upload image
410
+ # uploaded_file = st.file_uploader("Choose a watermelon image...", type=["jpg", "jpeg", "png"])
411
+
412
+ # if uploaded_file is not None:
413
+ # try:
414
+ # # Load the image
415
+ # img = image.load_img(uploaded_file, target_size=(224, 224))
416
+ # img_array = image.img_to_array(img)
417
+ # img_array = np.expand_dims(img_array, axis=0)
418
+ # img_array /= 255.0
419
+
420
+ # # Display uploaded image
421
+ # st.image(uploaded_file, caption="Uploaded Image", use_column_width=True)
422
+
423
+ # # Predict watermelon name
424
+ # pred_name = model_name.predict(img_array)
425
+ # predicted_name = 'Watermelon'
426
+
427
+ # # Predict watermelon quality
428
+ # pred_quality = model_quality.predict(img_array)
429
+ # predicted_class_quality = np.argmax(pred_quality, axis=1)
430
+
431
+ # # Define labels for watermelon quality
432
+ # label_map_quality = {
433
+ # 0: "Good",
434
+ # 1: "Mild",
435
+ # 2: "Rotten"
436
+ # }
437
+
438
+ # predicted_quality = label_map_quality[predicted_class_quality[0]]
439
+
440
+ # # Display predictions
441
+ # st.write(f"Fruit Type Detection: {predicted_name}")
442
+ # st.write(f"Fruit Quality Classification: {predicted_quality}")
443
+
444
+ # # If the quality is 'Mild' or 'Rotten', pass the image to the mask detection model
445
+ # if predicted_quality in ["Mild", "Rotten"]:
446
+ # st.write("Passing the image to the mask detection model for damage detection...")
447
+
448
+ # # Load the image again for the mask detection (Detectron2 requires the original image)
449
+ # im = cv2.imdecode(np.fromstring(uploaded_file.read(), np.uint8), 1)
450
+
451
+ # # Run prediction on the image
452
+ # outputs = predictor(im)
453
+
454
+ # # Visualize the predictions
455
+ # v = Visualizer(im[:, :, ::-1], MetadataCatalog.get(cfg.DATASETS.TRAIN[0]), scale=0.8)
456
+ # out = v.draw_instance_predictions(outputs["instances"].to("cpu"))
457
+
458
+ # # Display the output
459
+ # st.image(out.get_image()[:, :, ::-1], caption="Detected Damage", use_column_width=True)
460
+
461
+ # except Exception as e:
462
+ # st.error(f"An error occurred during processing: {str(e)}")
463
+
464
+
465
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
466
 
 
 
 
 
467
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
468
 
 
469
 
 
 
470
 
 
 
471
 
 
 
 
 
 
 
 
472
 
 
 
473
 
 
 
 
474
 
 
 
 
475
 
 
 
 
 
 
 
476
 
 
477
 
 
 
 
478
 
 
 
 
479
 
 
 
480
 
 
 
481
 
 
 
 
482
 
 
 
483
 
 
 
484
 
485
 
486
+
487
+ import gradio as gr
488
+ import numpy as np
489
+ import cv2
490
+ import torch
491
+ from PIL import Image
492
+ from tensorflow.keras.models import load_model
493
+ from tensorflow.keras.preprocessing import image
494
+ from detectron2.engine import DefaultPredictor
495
+ from detectron2.config import get_cfg
496
+ from detectron2.utils.visualizer import Visualizer
497
+ from detectron2.data import MetadataCatalog
498
+
499
+ # Suppress warnings
500
+ import warnings
501
+ import tensorflow as tf
502
+ warnings.filterwarnings("ignore")
503
+ tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)
504
+
505
+ # Load models
506
+ model_name = load_model('name_model_inception.h5')
507
+ model_quality = load_model('type_model_inception.h5')
508
+
509
+ # Detectron2 setup
510
+ def load_detectron_model(fruit_name):
511
+ cfg = get_cfg()
512
+ cfg.merge_from_file(f"{fruit_name.lower()}.yaml")
513
+ cfg.MODEL.WEIGHTS = f"{fruit_name}_model.pth"
514
+ cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = 0.5
515
+ cfg.MODEL.DEVICE = 'cpu'
516
+ predictor = DefaultPredictor(cfg)
517
+ return predictor, cfg
518
+
519
+ # Labels
520
+ label_map_name = {
521
+ 0: "Banana", 1: "Cucumber", 2: "Grape", 3: "Kaki", 4: "Papaya",
522
+ 5: "Peach", 6: "Pear", 7: "Pepper", 8: "Strawberry", 9: "Watermelon",
523
+ 10: "Tomato"
524
+ }
525
+ label_map_quality = {0: "Good", 1: "Mild", 2: "Rotten"}
526
+
527
+ def predict_fruit(img):
528
+ # Preprocess image
529
+ img = Image.fromarray(img.astype('uint8'), 'RGB')
530
+ img = img.resize((224, 224))
531
+ x = image.img_to_array(img)
532
+ x = np.expand_dims(x, axis=0)
533
+ x = x / 255.0
534
+
535
+ # Predict
536
+ pred_name = model_name.predict(x)
537
+ pred_quality = model_quality.predict(x)
538
+
539
+ predicted_name = label_map_name[np.argmax(pred_name, axis=1)[0]]
540
+ predicted_quality = label_map_quality[np.argmax(pred_quality, axis=1)[0]]
541
+
542
+ result = f"Fruit Type: {predicted_name}\nFruit Quality: {predicted_quality}"
543
+
544
+ # Damage detection for specific fruits
545
+ if predicted_name.lower() in ["kaki", "tomato", "strawberry", "pepper", "pear", "peach", "papaya", "watermelon", "grape", "banana", "cucumber"] and predicted_quality in ["Mild", "Rotten"]:
546
+ predictor, cfg = load_detectron_model(predicted_name)
547
+ outputs = predictor(cv2.cvtColor(np.array(img), cv2.COLOR_RGB2BGR))
548
+ v = Visualizer(np.array(img), MetadataCatalog.get(cfg.DATASETS.TRAIN[0]), scale=0.8)
549
+ out = v.draw_instance_predictions(outputs["instances"].to("cpu"))
550
+ result_image = out.get_image()
551
+ else:
552
+ result_image = np.array(img)
553
+
554
+ return result, result_image
555
+
556
+ iface = gr.Interface(
557
+ fn=predict_fruit,
558
+ inputs=gr.Image(),
559
+ outputs=[gr.Textbox(), gr.Image()],
560
+ title="Fruit Quality and Damage Detection",
561
+ description="Upload an image of a fruit to detect its type, quality, and potential damage."
562
+ )
563
+
564
+ iface.launch()
565
 
566
 
567