# Gradio app: Melanitis leda butterfly landmark identification using
# ultralytics YOLO pose models (dry/wet/all season forms).
from functools import lru_cache
from pathlib import Path

import cv2
import gradio as gr
import numpy as np
from PIL import Image
from ultralytics import YOLO
path = Path(__file__).parent
# Model configurations
MODEL_CONFIGS = {
"Dry Season Form": {
"path": path/"models/DSF_Mleda_250.pt",
"labels": [
"F1", "F2", "F3", "F4", "F5", "F6", "F7", "F8", "F9", "F10", "F11", "F12", "F13", "F14", "F15", "F16",
"H1", "H2", "H3", "H4", "H5", "H6", "H7", "H8", "H9", "H10", "H11", "H12", "H13", "H14", "H15", "H16", "H17",
"fs1", "fs2", "fs3", "fs4", "fs5",
"hs1", "hs2", "hs3", "hs4", "hs5", "hs6",
"sc1", "sc2",
"sex", "right", "left", "grey", "black", "white"
],
"imgsz": 1280
},
"Wet Season Form": {
"path": path/"models/WSF_Mleda_200.pt",
"labels": [
'F12', 'F14', 'fs1', 'fs2', 'fs3', 'fs4', 'fs5',
'H12', 'H14', 'hs1', 'hs2', 'hs3', 'hs4', 'hs5', 'hs6',
'white', 'black', 'grey', 'sex', 'blue', 'green', 'red', 'sc1', 'sc2'
],
"imgsz": 1280
},
"All Season Form": {
"path": path/"models/DSF_WSF_Mleda_450.pt",
"labels": [
"F1", "F2", "F3", "F4", "F5", "F6", "F7", "F8", "F9", "F10", "F11", "F12", "F13", "F14", "F15", "F16",
"H1", "H2", "H3", "H4", "H5", "H6", "H7", "H8", "H9", "H10", "H11", "H12", "H13", "H14", "H15", "H16", "H17",
"fs1", "fs2", "fs3", "fs4", "fs5",
"hs1", "hs2", "hs3", "hs4", "hs5", "hs6",
"sc1", "sc2",
"sex", "right", "left", "grey", "black", "white"
],
"imgsz": 1280
}
}
def hex_to_bgr(hex_color: str) -> tuple:
"""Convert #RRGGBB hex color to BGR tuple."""
hex_color = hex_color.lstrip("#")
if len(hex_color) != 6:
return (0, 255, 0) # Default to green if invalid
r = int(hex_color[0:2], 16)
g = int(hex_color[2:4], 16)
b = int(hex_color[4:6], 16)
return (b, g, r)
def load_model(path):
"""Load YOLO model from the given path."""
return YOLO(path)
def draw_detections(image: np.ndarray, results, labels, keypoint_threshold: float, show_labels: bool, point_size: int, point_color: str, label_size: float) -> np.ndarray:
img = image.copy()
color_bgr = hex_to_bgr(point_color)
for result in results:
boxes = result.boxes.xywh.cpu().numpy()
cls_ids = result.boxes.cls.int().cpu().numpy()
confs = result.boxes.conf.cpu().numpy()
kpts_all = result.keypoints.data.cpu().numpy()
for (x_c, y_c, w, h), cls_id, conf, kpts in zip(boxes, cls_ids, confs, kpts_all):
x1 = int(x_c - w/2); y1 = int(y_c - h/2)
x2 = int(x_c + w/2); y2 = int(y_c + h/2)
cv2.rectangle(img, (x1, y1), (x2, y2), color=(255,255,0), thickness=2)
class_name = result.names[int(cls_id)]
text = f"{class_name} {conf:.2f}"
(tw, th), _ = cv2.getTextSize(text, cv2.FONT_HERSHEY_SIMPLEX, 0.6, 1)
cv2.rectangle(img, (x1, y1 - th - 4), (x1 + tw, y1), (255,255,0), cv2.FILLED)
cv2.putText(img, text, (x1, y1 - 4), cv2.FONT_HERSHEY_SIMPLEX, 0.6, (0,0,0), 1, cv2.LINE_AA)
for i, (x, y, v) in enumerate(kpts):
if v > keypoint_threshold and i < len(labels):
xi, yi = int(x), int(y)
cv2.circle(img, (xi, yi), int(point_size), color_bgr, -1, cv2.LINE_AA)
if show_labels:
cv2.putText(img, labels[i], (xi + 3, yi + 3), cv2.FONT_HERSHEY_SIMPLEX, label_size, (255,0,0), 2, cv2.LINE_AA)
return img
def predict_and_annotate(input_image: Image.Image, conf_threshold: float, keypoint_threshold: float, model_choice: str, show_labels: bool, point_size: int, point_color: str, label_size: float):
config = MODEL_CONFIGS[model_choice]
model = load_model(config["path"])
labels = config["labels"]
imgsz = config["imgsz"]
img_rgb = np.array(input_image.convert("RGB"))
img_bgr = cv2.cvtColor(img_rgb, cv2.COLOR_RGB2BGR)
results = model(img_bgr, conf=conf_threshold, imgsz=imgsz)
annotated_bgr = draw_detections(img_bgr, results, labels, keypoint_threshold, show_labels, point_size, point_color, label_size)
annotated_rgb = cv2.cvtColor(annotated_bgr, cv2.COLOR_BGR2RGB)
return Image.fromarray(annotated_rgb)
# Gradio Interface
app = gr.Interface(
fn=predict_and_annotate,
inputs=[
gr.Image(type="pil", label="Upload Image"),
gr.Slider(minimum=0.0, maximum=1.0, value=0.25, step=0.01, label="Confidence Threshold"),
gr.Slider(minimum=0.0, maximum=1.0, value=0.5, step=0.01, label="Keypoint Visibility Threshold"),
gr.Radio(choices=list(MODEL_CONFIGS.keys()), label="Select Model", value="Dry Season Form"),
gr.Checkbox(label="Show Keypoint Labels", value=True),
gr.Slider(minimum=1, maximum=20, value=8, step=0.1, label="Keypoint Size"),
gr.ColorPicker(label="Keypoint Color", value="#00FF00"),
gr.Slider(minimum=0.3, maximum=3.0, value=1.0, step=0.1, label="Keypoint Label Font Size")
],
outputs=gr.Image(type="pil", label="Detection Result", format="png"),
title="🦋 Melanitis leda Landmark Identification",
description="Upload an image and select the model. Customize detection and keypoint display settings.",
flagging_mode="never"
)
if __name__ == "__main__":
app.launch()