import gradio as gr
from ultralytics import YOLO
import cv2
import numpy as np
from PIL import Image
import os


# Load the model with proper error handling
def load_model():
    model_paths = [
        'best_model.pt',
        'tree_disease_detector.pt',
        './best_model.pt',
        './tree_disease_detector.pt'
    ]

    # Try to load from local files first
    for path in model_paths:
        if os.path.exists(path):
            try:
                print(f"Loading model from {path}")
                model = YOLO(path)
                return model, f"Tree Disease Detection Model ({path})"
            except Exception as e:
                print(f"Error loading {path}: {e}")
                continue

    # Fallback to standard YOLOv8s
    try:
        print("Loading standard YOLOv8s model...")
        model = YOLO('yolov8s.pt')
        return model, "Standard YOLOv8s Model (Fallback)"
    except Exception as e:
        print(f"Error loading YOLOv8s: {e}")
        return None, "No model available"


# Load model and get status
model, model_status = load_model()
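# Note: load_model() runs once at import time, so the same model instance and
# status string are reused for every Gradio request.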

def detect_tree_disease(image, conf_threshold=0.25, iou_threshold=0.45):
    """Detect unhealthy trees in the uploaded image"""
    if model is None:
        return image, "Error: No model available"

    # Convert PIL image to numpy array
    image_np = np.array(image)

    # Run inference
    results = model(image_np, conf=conf_threshold, iou=iou_threshold)

    # Get annotated image directly from results; plot() returns a BGR array,
    # so convert to RGB before wrapping it in a PIL image
    annotated_img = results[0].plot()
    annotated_img = cv2.cvtColor(annotated_img, cv2.COLOR_BGR2RGB)
    annotated_img = Image.fromarray(annotated_img)

    # Extract detections
    detections = []
    for result in results:
        boxes = result.boxes
        if boxes is not None:
            for box in boxes:
                detection = {
                    'confidence': float(box.conf[0]),
                    'bbox': box.xyxy[0].tolist(),
                    'class': 'unhealthy'
                }
                detections.append(detection)

    # Create detection summary
    is_custom_model = "Tree Disease Detection Model" in model_status
    if is_custom_model:
        summary = f"Detected {len(detections)} unhealthy tree(s)\n\n"
        for i, det in enumerate(detections, 1):
            summary += f"Tree {i}: Confidence {det['confidence']:.2f}\n"
    else:
        summary = f"Using {model_status}\n"
        summary += f"Detected {len(detections)} object(s)\n\n"
        for i, det in enumerate(detections, 1):
            summary += f"Object {i}: Confidence {det['confidence']:.2f}\n"

    summary += f"\nModel Status: {model_status}"

    return annotated_img, summary
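
# A minimal sketch of calling the detector outside the Gradio UI, assuming a
# local test image named "sample.jpg" exists (hypothetical filename):
#
#   annotated, summary = detect_tree_disease(Image.open("sample.jpg"),
#                                            conf_threshold=0.3,
#                                            iou_threshold=0.5)
#   annotated.save("annotated.jpg")
#   print(summary)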

# Create example images (tree images)
example_images = [
    ["https://images.pexels.com/photos/1632790/pexels-photo-1632790.jpeg", 0.25, 0.45],
    ["https://images.pexels.com/photos/38537/woodland-road-falling-leaf-natural-38537.jpeg", 0.25, 0.45],
    ["https://upload.wikimedia.org/wikipedia/commons/thumb/e/eb/Ash_Tree_-_geograph.org.uk_-_590710.jpg/640px-Ash_Tree_-_geograph.org.uk_-_590710.jpg", 0.25, 0.45],
]
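# Each example row maps to the gr.Examples inputs defined below:
# [image URL, confidence threshold, IoU threshold].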

# Create Gradio interface
with gr.Blocks(title="Tree Disease Detection") as demo:
    gr.Markdown(f"""
    # 🌳 Tree Disease Detection with YOLOv8

    This model detects unhealthy/diseased trees in aerial UAV imagery.
    Upload an image or use one of the examples below to detect diseased trees.

    **Current Model**: {model_status}
    """)

    if "Fallback" in model_status:
        gr.Markdown("""
        ⚠️ **Note**: Using a fallback model. Detection will work but won't be specific to tree diseases.
        """)

    with gr.Row():
        with gr.Column():
            input_image = gr.Image(type="pil", label="Upload Image")
            conf_threshold = gr.Slider(
                minimum=0.0,
                maximum=1.0,
                value=0.25,
                step=0.05,
                label="Confidence Threshold"
            )
            iou_threshold = gr.Slider(
                minimum=0.0,
                maximum=1.0,
                value=0.45,
                step=0.05,
                label="IoU Threshold"
            )
            detect_button = gr.Button("Detect Tree Disease", variant="primary")

        with gr.Column():
            output_image = gr.Image(type="pil", label="Detection Results")
            detection_summary = gr.Textbox(label="Detection Summary", lines=10)

    # Set up event handler
    detect_button.click(
        fn=detect_tree_disease,
        inputs=[input_image, conf_threshold, iou_threshold],
        outputs=[output_image, detection_summary]
    )

    # Add examples
    gr.Examples(
        examples=example_images,
        inputs=[input_image, conf_threshold, iou_threshold],
        outputs=[output_image, detection_summary],
        fn=detect_tree_disease,
        cache_examples=False,
    )

    gr.Markdown("""
    ## About this Model

    - **Architecture**: YOLOv8s
    - **Dataset**: [PDT Dataset](https://huggingface.co/datasets/qwer0213/PDT_dataset)
    - **mAP50**: 0.933
    - **mAP50-95**: 0.659
    - **Precision**: 0.878
    - **Recall**: 0.863
    - **Classes**: 1 (unhealthy trees)

    ## Usage Tips

    - This model works best with aerial/UAV imagery
    - Optimal input resolution: 640x640 pixels
    - Adjust confidence threshold to filter detections
    - Lower IoU threshold for overlapping trees

    [Model Card](https://huggingface.co/IsmatS/crop_desease_detection) |
    [Dataset](https://huggingface.co/datasets/qwer0213/PDT_dataset)
    """)

# Launch the app
if __name__ == "__main__":
    demo.launch()
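    # Assumption: when running locally, demo.launch(share=True) can also create a
    # temporary public link; on Hugging Face Spaces the default launch() is enough.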