Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
@@ -341,127 +341,227 @@
|
|
341 |
# ///////////////////////////////////Working
|
342 |
|
343 |
|
344 |
-
import streamlit as st
|
345 |
-
import numpy as np
|
346 |
-
import cv2
|
347 |
-
import warnings
|
348 |
-
import os
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
349 |
|
350 |
-
# Suppress warnings
|
351 |
-
warnings.filterwarnings("ignore", category=FutureWarning)
|
352 |
-
warnings.filterwarnings("ignore", category=UserWarning)
|
353 |
-
|
354 |
-
# Try importing TensorFlow
|
355 |
-
try:
|
356 |
-
from tensorflow.keras.models import load_model
|
357 |
-
from tensorflow.keras.preprocessing import image
|
358 |
-
except ImportError:
|
359 |
-
st.error("Failed to import TensorFlow. Please make sure it's installed correctly.")
|
360 |
-
|
361 |
-
# Try importing PyTorch and Detectron2
|
362 |
-
try:
|
363 |
-
import torch
|
364 |
-
import detectron2
|
365 |
-
except ImportError:
|
366 |
-
with st.spinner("Installing PyTorch and Detectron2..."):
|
367 |
-
os.system("pip install torch torchvision")
|
368 |
-
os.system("pip install 'git+https://github.com/facebookresearch/detectron2.git'")
|
369 |
-
|
370 |
-
import torch
|
371 |
-
import detectron2
|
372 |
|
373 |
-
from detectron2.engine import DefaultPredictor
|
374 |
-
from detectron2.config import get_cfg
|
375 |
-
from detectron2.utils.visualizer import Visualizer
|
376 |
-
from detectron2.data import MetadataCatalog
|
377 |
|
378 |
-
# Load the trained models
|
379 |
-
@st.cache_resource
|
380 |
-
def load_models():
|
381 |
-
try:
|
382 |
-
model_path_name = 'name_model_inception.h5'
|
383 |
-
model_path_quality = 'type_model_inception.h5'
|
384 |
-
model_name = load_model(model_path_name)
|
385 |
-
model_quality = load_model(model_path_quality)
|
386 |
-
return model_name, model_quality
|
387 |
-
except Exception as e:
|
388 |
-
st.error(f"Failed to load models: {str(e)}")
|
389 |
-
return None, None
|
390 |
-
|
391 |
-
model_name, model_quality = load_models()
|
392 |
-
|
393 |
-
# Setup Detectron2 configuration for watermelon
|
394 |
-
@st.cache_resource
|
395 |
-
def load_detectron_model():
|
396 |
-
cfg = get_cfg()
|
397 |
-
cfg.merge_from_file("watermelon.yaml")
|
398 |
-
cfg.MODEL.WEIGHTS = "Watermelon_model.pth"
|
399 |
-
cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = 0.5
|
400 |
-
cfg.MODEL.DEVICE = 'cpu' # Use CPU for inference
|
401 |
-
predictor = DefaultPredictor(cfg)
|
402 |
-
return predictor, cfg
|
403 |
|
404 |
-
predictor, cfg = load_detectron_model()
|
405 |
|
406 |
-
# Streamlit app title
|
407 |
-
st.title("Watermelon Quality and Damage Detection")
|
408 |
|
409 |
-
# Upload image
|
410 |
-
uploaded_file = st.file_uploader("Choose a watermelon image...", type=["jpg", "jpeg", "png"])
|
411 |
|
412 |
-
if uploaded_file is not None:
|
413 |
-
try:
|
414 |
-
# Load the image
|
415 |
-
img = image.load_img(uploaded_file, target_size=(224, 224))
|
416 |
-
img_array = image.img_to_array(img)
|
417 |
-
img_array = np.expand_dims(img_array, axis=0)
|
418 |
-
img_array /= 255.0
|
419 |
|
420 |
-
# Display uploaded image
|
421 |
-
st.image(uploaded_file, caption="Uploaded Image", use_column_width=True)
|
422 |
|
423 |
-
# Predict watermelon name
|
424 |
-
pred_name = model_name.predict(img_array)
|
425 |
-
predicted_name = 'Watermelon'
|
426 |
|
427 |
-
# Predict watermelon quality
|
428 |
-
pred_quality = model_quality.predict(img_array)
|
429 |
-
predicted_class_quality = np.argmax(pred_quality, axis=1)
|
430 |
|
431 |
-
# Define labels for watermelon quality
|
432 |
-
label_map_quality = {
|
433 |
-
0: "Good",
|
434 |
-
1: "Mild",
|
435 |
-
2: "Rotten"
|
436 |
-
}
|
437 |
|
438 |
-
predicted_quality = label_map_quality[predicted_class_quality[0]]
|
439 |
|
440 |
-
# Display predictions
|
441 |
-
st.write(f"Fruit Type Detection: {predicted_name}")
|
442 |
-
st.write(f"Fruit Quality Classification: {predicted_quality}")
|
443 |
|
444 |
-
# If the quality is 'Mild' or 'Rotten', pass the image to the mask detection model
|
445 |
-
if predicted_quality in ["Mild", "Rotten"]:
|
446 |
-
st.write("Passing the image to the mask detection model for damage detection...")
|
447 |
|
448 |
-
# Load the image again for the mask detection (Detectron2 requires the original image)
|
449 |
-
im = cv2.imdecode(np.fromstring(uploaded_file.read(), np.uint8), 1)
|
450 |
|
451 |
-
# Run prediction on the image
|
452 |
-
outputs = predictor(im)
|
453 |
|
454 |
-
# Visualize the predictions
|
455 |
-
v = Visualizer(im[:, :, ::-1], MetadataCatalog.get(cfg.DATASETS.TRAIN[0]), scale=0.8)
|
456 |
-
out = v.draw_instance_predictions(outputs["instances"].to("cpu"))
|
457 |
|
458 |
-
# Display the output
|
459 |
-
st.image(out.get_image()[:, :, ::-1], caption="Detected Damage", use_column_width=True)
|
460 |
|
461 |
-
except Exception as e:
|
462 |
-
st.error(f"An error occurred during processing: {str(e)}")
|
463 |
|
464 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
465 |
|
466 |
|
467 |
|
|
|
341 |
# ///////////////////////////////////Working
|
342 |
|
343 |
|
344 |
+
# import streamlit as st
|
345 |
+
# import numpy as np
|
346 |
+
# import cv2
|
347 |
+
# import warnings
|
348 |
+
# import os
|
349 |
+
|
350 |
+
# # Suppress warnings
|
351 |
+
# warnings.filterwarnings("ignore", category=FutureWarning)
|
352 |
+
# warnings.filterwarnings("ignore", category=UserWarning)
|
353 |
+
|
354 |
+
# # Try importing TensorFlow
|
355 |
+
# try:
|
356 |
+
# from tensorflow.keras.models import load_model
|
357 |
+
# from tensorflow.keras.preprocessing import image
|
358 |
+
# except ImportError:
|
359 |
+
# st.error("Failed to import TensorFlow. Please make sure it's installed correctly.")
|
360 |
+
|
361 |
+
# # Try importing PyTorch and Detectron2
|
362 |
+
# try:
|
363 |
+
# import torch
|
364 |
+
# import detectron2
|
365 |
+
# except ImportError:
|
366 |
+
# with st.spinner("Installing PyTorch and Detectron2..."):
|
367 |
+
# os.system("pip install torch torchvision")
|
368 |
+
# os.system("pip install 'git+https://github.com/facebookresearch/detectron2.git'")
|
369 |
+
|
370 |
+
# import torch
|
371 |
+
# import detectron2
|
372 |
+
|
373 |
+
# from detectron2.engine import DefaultPredictor
|
374 |
+
# from detectron2.config import get_cfg
|
375 |
+
# from detectron2.utils.visualizer import Visualizer
|
376 |
+
# from detectron2.data import MetadataCatalog
|
377 |
+
|
378 |
+
# # Load the trained models
|
379 |
+
# @st.cache_resource
|
380 |
+
# def load_models():
|
381 |
+
# try:
|
382 |
+
# model_path_name = 'name_model_inception.h5'
|
383 |
+
# model_path_quality = 'type_model_inception.h5'
|
384 |
+
# model_name = load_model(model_path_name)
|
385 |
+
# model_quality = load_model(model_path_quality)
|
386 |
+
# return model_name, model_quality
|
387 |
+
# except Exception as e:
|
388 |
+
# st.error(f"Failed to load models: {str(e)}")
|
389 |
+
# return None, None
|
390 |
+
|
391 |
+
# model_name, model_quality = load_models()
|
392 |
+
|
393 |
+
# # Setup Detectron2 configuration for watermelon
|
394 |
+
# @st.cache_resource
|
395 |
+
# def load_detectron_model():
|
396 |
+
# cfg = get_cfg()
|
397 |
+
# cfg.merge_from_file("watermelon.yaml")
|
398 |
+
# cfg.MODEL.WEIGHTS = "Watermelon_model.pth"
|
399 |
+
# cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = 0.5
|
400 |
+
# cfg.MODEL.DEVICE = 'cpu' # Use CPU for inference
|
401 |
+
# predictor = DefaultPredictor(cfg)
|
402 |
+
# return predictor, cfg
|
403 |
+
|
404 |
+
# predictor, cfg = load_detectron_model()
|
405 |
+
|
406 |
+
# # Streamlit app title
|
407 |
+
# st.title("Watermelon Quality and Damage Detection")
|
408 |
+
|
409 |
+
# # Upload image
|
410 |
+
# uploaded_file = st.file_uploader("Choose a watermelon image...", type=["jpg", "jpeg", "png"])
|
411 |
+
|
412 |
+
# if uploaded_file is not None:
|
413 |
+
# try:
|
414 |
+
# # Load the image
|
415 |
+
# img = image.load_img(uploaded_file, target_size=(224, 224))
|
416 |
+
# img_array = image.img_to_array(img)
|
417 |
+
# img_array = np.expand_dims(img_array, axis=0)
|
418 |
+
# img_array /= 255.0
|
419 |
+
|
420 |
+
# # Display uploaded image
|
421 |
+
# st.image(uploaded_file, caption="Uploaded Image", use_column_width=True)
|
422 |
+
|
423 |
+
# # Predict watermelon name
|
424 |
+
# pred_name = model_name.predict(img_array)
|
425 |
+
# predicted_name = 'Watermelon'
|
426 |
+
|
427 |
+
# # Predict watermelon quality
|
428 |
+
# pred_quality = model_quality.predict(img_array)
|
429 |
+
# predicted_class_quality = np.argmax(pred_quality, axis=1)
|
430 |
+
|
431 |
+
# # Define labels for watermelon quality
|
432 |
+
# label_map_quality = {
|
433 |
+
# 0: "Good",
|
434 |
+
# 1: "Mild",
|
435 |
+
# 2: "Rotten"
|
436 |
+
# }
|
437 |
+
|
438 |
+
# predicted_quality = label_map_quality[predicted_class_quality[0]]
|
439 |
+
|
440 |
+
# # Display predictions
|
441 |
+
# st.write(f"Fruit Type Detection: {predicted_name}")
|
442 |
+
# st.write(f"Fruit Quality Classification: {predicted_quality}")
|
443 |
+
|
444 |
+
# # If the quality is 'Mild' or 'Rotten', pass the image to the mask detection model
|
445 |
+
# if predicted_quality in ["Mild", "Rotten"]:
|
446 |
+
# st.write("Passing the image to the mask detection model for damage detection...")
|
447 |
+
|
448 |
+
# # Load the image again for the mask detection (Detectron2 requires the original image)
|
449 |
+
# im = cv2.imdecode(np.fromstring(uploaded_file.read(), np.uint8), 1)
|
450 |
+
|
451 |
+
# # Run prediction on the image
|
452 |
+
# outputs = predictor(im)
|
453 |
+
|
454 |
+
# # Visualize the predictions
|
455 |
+
# v = Visualizer(im[:, :, ::-1], MetadataCatalog.get(cfg.DATASETS.TRAIN[0]), scale=0.8)
|
456 |
+
# out = v.draw_instance_predictions(outputs["instances"].to("cpu"))
|
457 |
+
|
458 |
+
# # Display the output
|
459 |
+
# st.image(out.get_image()[:, :, ::-1], caption="Detected Damage", use_column_width=True)
|
460 |
+
|
461 |
+
# except Exception as e:
|
462 |
+
# st.error(f"An error occurred during processing: {str(e)}")
|
463 |
+
|
464 |
+
|
465 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
466 |
|
|
|
|
|
|
|
|
|
467 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
468 |
|
|
|
469 |
|
|
|
|
|
470 |
|
|
|
|
|
471 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
472 |
|
|
|
|
|
473 |
|
|
|
|
|
|
|
474 |
|
|
|
|
|
|
|
475 |
|
|
|
|
|
|
|
|
|
|
|
|
|
476 |
|
|
|
477 |
|
|
|
|
|
|
|
478 |
|
|
|
|
|
|
|
479 |
|
|
|
|
|
480 |
|
|
|
|
|
481 |
|
|
|
|
|
|
|
482 |
|
|
|
|
|
483 |
|
|
|
|
|
484 |
|
485 |
|
486 |
+
|
487 |
+
import gradio as gr
|
488 |
+
import numpy as np
|
489 |
+
import cv2
|
490 |
+
import torch
|
491 |
+
from PIL import Image
|
492 |
+
from tensorflow.keras.models import load_model
|
493 |
+
from tensorflow.keras.preprocessing import image
|
494 |
+
from detectron2.engine import DefaultPredictor
|
495 |
+
from detectron2.config import get_cfg
|
496 |
+
from detectron2.utils.visualizer import Visualizer
|
497 |
+
from detectron2.data import MetadataCatalog
|
498 |
+
|
499 |
+
# Suppress warnings
|
500 |
+
import warnings
|
501 |
+
import tensorflow as tf
|
502 |
+
warnings.filterwarnings("ignore")
|
503 |
+
tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)
|
504 |
+
|
505 |
+
# Load models
|
506 |
+
model_name = load_model('name_model_inception.h5')
|
507 |
+
model_quality = load_model('type_model_inception.h5')
|
508 |
+
|
509 |
+
# Detectron2 setup
|
510 |
+
def load_detectron_model(fruit_name):
|
511 |
+
cfg = get_cfg()
|
512 |
+
cfg.merge_from_file(f"{fruit_name.lower()}.yaml")
|
513 |
+
cfg.MODEL.WEIGHTS = f"{fruit_name}_model.pth"
|
514 |
+
cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = 0.5
|
515 |
+
cfg.MODEL.DEVICE = 'cpu'
|
516 |
+
predictor = DefaultPredictor(cfg)
|
517 |
+
return predictor, cfg
|
518 |
+
|
519 |
+
# Labels
|
520 |
+
label_map_name = {
|
521 |
+
0: "Banana", 1: "Cucumber", 2: "Grape", 3: "Kaki", 4: "Papaya",
|
522 |
+
5: "Peach", 6: "Pear", 7: "Pepper", 8: "Strawberry", 9: "Watermelon",
|
523 |
+
10: "Tomato"
|
524 |
+
}
|
525 |
+
label_map_quality = {0: "Good", 1: "Mild", 2: "Rotten"}
|
526 |
+
|
527 |
+
def predict_fruit(img):
|
528 |
+
# Preprocess image
|
529 |
+
img = Image.fromarray(img.astype('uint8'), 'RGB')
|
530 |
+
img = img.resize((224, 224))
|
531 |
+
x = image.img_to_array(img)
|
532 |
+
x = np.expand_dims(x, axis=0)
|
533 |
+
x = x / 255.0
|
534 |
+
|
535 |
+
# Predict
|
536 |
+
pred_name = model_name.predict(x)
|
537 |
+
pred_quality = model_quality.predict(x)
|
538 |
+
|
539 |
+
predicted_name = label_map_name[np.argmax(pred_name, axis=1)[0]]
|
540 |
+
predicted_quality = label_map_quality[np.argmax(pred_quality, axis=1)[0]]
|
541 |
+
|
542 |
+
result = f"Fruit Type: {predicted_name}\nFruit Quality: {predicted_quality}"
|
543 |
+
|
544 |
+
# Damage detection for specific fruits
|
545 |
+
if predicted_name.lower() in ["kaki", "tomato", "strawberry", "pepper", "pear", "peach", "papaya", "watermelon", "grape", "banana", "cucumber"] and predicted_quality in ["Mild", "Rotten"]:
|
546 |
+
predictor, cfg = load_detectron_model(predicted_name)
|
547 |
+
outputs = predictor(cv2.cvtColor(np.array(img), cv2.COLOR_RGB2BGR))
|
548 |
+
v = Visualizer(np.array(img), MetadataCatalog.get(cfg.DATASETS.TRAIN[0]), scale=0.8)
|
549 |
+
out = v.draw_instance_predictions(outputs["instances"].to("cpu"))
|
550 |
+
result_image = out.get_image()
|
551 |
+
else:
|
552 |
+
result_image = np.array(img)
|
553 |
+
|
554 |
+
return result, result_image
|
555 |
+
|
556 |
+
iface = gr.Interface(
|
557 |
+
fn=predict_fruit,
|
558 |
+
inputs=gr.Image(),
|
559 |
+
outputs=[gr.Textbox(), gr.Image()],
|
560 |
+
title="Fruit Quality and Damage Detection",
|
561 |
+
description="Upload an image of a fruit to detect its type, quality, and potential damage."
|
562 |
+
)
|
563 |
+
|
564 |
+
iface.launch()
|
565 |
|
566 |
|
567 |
|