# Melanitis leda landmark identification — Gradio app running YOLO keypoint models.
from pathlib import Path

import cv2
import gradio as gr
import numpy as np
from PIL import Image
from ultralytics import YOLO

# Directory containing this file; model weights live in a sibling "models" folder.
path = Path(__file__).parent
# Model configurations.
# Each entry maps a user-facing model name to:
#   "path"  — location of the YOLO pose weights file,
#   "labels" — keypoint names, ordered to match the model's keypoint indices,
#   "imgsz" — inference image size passed to YOLO.
MODEL_CONFIGS = {
    "Dry Season Form": {
        "path": path / "models/DSF_Mleda_250.pt",
        "labels": [
            "F1", "F2", "F3", "F4", "F5", "F6", "F7", "F8", "F9", "F10", "F11", "F12", "F13", "F14", "F15", "F16",
            "H1", "H2", "H3", "H4", "H5", "H6", "H7", "H8", "H9", "H10", "H11", "H12", "H13", "H14", "H15", "H16", "H17",
            "fs1", "fs2", "fs3", "fs4", "fs5",
            "hs1", "hs2", "hs3", "hs4", "hs5", "hs6",
            "sc1", "sc2",
            "sex", "right", "left", "grey", "black", "white",
        ],
        "imgsz": 1280,
    },
    "Wet Season Form": {
        "path": path / "models/WSF_Mleda_200.pt",
        "labels": [
            "F12", "F14", "fs1", "fs2", "fs3", "fs4", "fs5",
            "H12", "H14", "hs1", "hs2", "hs3", "hs4", "hs5", "hs6",
            "white", "black", "grey", "sex", "blue", "green", "red", "sc1", "sc2",
        ],
        "imgsz": 1280,
    },
    "All Season Form": {
        "path": path / "models/DSF_WSF_Mleda_450.pt",
        "labels": [
            "F1", "F2", "F3", "F4", "F5", "F6", "F7", "F8", "F9", "F10", "F11", "F12", "F13", "F14", "F15", "F16",
            "H1", "H2", "H3", "H4", "H5", "H6", "H7", "H8", "H9", "H10", "H11", "H12", "H13", "H14", "H15", "H16", "H17",
            "fs1", "fs2", "fs3", "fs4", "fs5",
            "hs1", "hs2", "hs3", "hs4", "hs5", "hs6",
            "sc1", "sc2",
            "sex", "right", "left", "grey", "black", "white",
        ],
        "imgsz": 1280,
    },
}
def hex_to_bgr(hex_color: str) -> tuple:
    """Convert a ``#RRGGBB`` hex color string to an OpenCV-style BGR tuple.

    Args:
        hex_color: Color string such as ``"#1A2B3C"``; the leading ``#`` is optional.

    Returns:
        ``(b, g, r)`` tuple of ints in 0-255. Falls back to green ``(0, 255, 0)``
        when the string is not exactly six hex digits or contains invalid digits.
    """
    hex_color = hex_color.lstrip("#")
    if len(hex_color) != 6:
        return (0, 255, 0)  # Default to green if invalid
    try:
        r = int(hex_color[0:2], 16)
        g = int(hex_color[2:4], 16)
        b = int(hex_color[4:6], 16)
    except ValueError:
        return (0, 255, 0)  # Non-hex characters — same green fallback as bad length
    return (b, g, r)
# Cache of loaded models keyed by weights path, so repeated predictions
# don't reload the same weights file from disk on every request.
_MODEL_CACHE: dict = {}

def load_model(path):
    """Load (or reuse a cached) YOLO model for the given weights path.

    Args:
        path: Filesystem path to the ``.pt`` weights file.

    Returns:
        The ``ultralytics.YOLO`` model instance for that path.
    """
    key = str(path)
    if key not in _MODEL_CACHE:
        _MODEL_CACHE[key] = YOLO(path)
    return _MODEL_CACHE[key]
def draw_detections(image: np.ndarray, results, labels, keypoint_threshold: float, show_labels: bool, point_size: int, point_color: str, label_size: float) -> np.ndarray:
    """Render YOLO boxes and keypoints onto a copy of a BGR image.

    Args:
        image: BGR image (``np.ndarray``) to annotate; not modified in place.
        results: Iterable of ultralytics result objects exposing ``.boxes``
            (xywh/cls/conf tensors), ``.keypoints.data`` and ``.names``.
        labels: Keypoint names indexed by keypoint position; keypoints beyond
            ``len(labels)`` are skipped.
        keypoint_threshold: Minimum visibility/confidence for drawing a keypoint.
        show_labels: Whether to draw each keypoint's text label.
        point_size: Keypoint circle radius in pixels (truncated to int).
        point_color: ``#RRGGBB`` hex string for keypoint circles.
        label_size: Font scale for keypoint labels.

    Returns:
        The annotated BGR image copy.
    """
    img = image.copy()
    color_bgr = hex_to_bgr(point_color)
    font = cv2.FONT_HERSHEY_SIMPLEX
    for result in results:
        boxes = result.boxes.xywh.cpu().numpy()
        cls_ids = result.boxes.cls.int().cpu().numpy()
        confs = result.boxes.conf.cpu().numpy()
        kpts_all = result.keypoints.data.cpu().numpy()
        for (x_c, y_c, w, h), cls_id, conf, kpts in zip(boxes, cls_ids, confs, kpts_all):
            # Convert center/size box to corner coordinates for cv2.rectangle.
            x1 = int(x_c - w / 2); y1 = int(y_c - h / 2)
            x2 = int(x_c + w / 2); y2 = int(y_c + h / 2)
            cv2.rectangle(img, (x1, y1), (x2, y2), color=(255, 255, 0), thickness=2)
            class_name = result.names[int(cls_id)]
            text = f"{class_name} {conf:.2f}"
            (tw, th), _ = cv2.getTextSize(text, font, 0.6, 1)
            # Clamp the label background so boxes at the top edge don't draw at negative y.
            bg_top = max(y1 - th - 4, 0)
            cv2.rectangle(img, (x1, bg_top), (x1 + tw, max(y1, th + 4)), (255, 255, 0), cv2.FILLED)
            cv2.putText(img, text, (x1, max(y1 - 4, th)), font, 0.6, (0, 0, 0), 1, cv2.LINE_AA)
            for i, (x, y, v) in enumerate(kpts):
                # v is the per-keypoint visibility score; skip low-confidence or unlabeled points.
                if v > keypoint_threshold and i < len(labels):
                    xi, yi = int(x), int(y)
                    cv2.circle(img, (xi, yi), int(point_size), color_bgr, -1, cv2.LINE_AA)
                    if show_labels:
                        cv2.putText(img, labels[i], (xi + 3, yi + 3), font, label_size, (255, 0, 0), 2, cv2.LINE_AA)
    return img
def predict_and_annotate(input_image: Image.Image, conf_threshold: float, keypoint_threshold: float, model_choice: str, show_labels: bool, point_size: int, point_color: str, label_size: float) -> Image.Image:
    """Run the selected YOLO model on an uploaded image and return the annotated result.

    Args:
        input_image: Uploaded PIL image (any mode; converted to RGB).
        conf_threshold: Detection confidence threshold passed to YOLO.
        keypoint_threshold: Visibility threshold for drawing keypoints.
        model_choice: Key into ``MODEL_CONFIGS`` selecting weights/labels/imgsz.
        show_labels: Whether keypoint labels are drawn.
        point_size: Keypoint circle radius in pixels.
        point_color: ``#RRGGBB`` hex string for keypoint circles.
        label_size: Font scale for keypoint labels.

    Returns:
        A PIL RGB image with boxes and keypoints drawn.
    """
    config = MODEL_CONFIGS[model_choice]
    model = load_model(config["path"])
    labels = config["labels"]
    imgsz = config["imgsz"]
    # YOLO/OpenCV work in BGR; convert from PIL's RGB and back at the edges.
    img_rgb = np.array(input_image.convert("RGB"))
    img_bgr = cv2.cvtColor(img_rgb, cv2.COLOR_RGB2BGR)
    results = model(img_bgr, conf=conf_threshold, imgsz=imgsz)
    annotated_bgr = draw_detections(img_bgr, results, labels, keypoint_threshold, show_labels, point_size, point_color, label_size)
    annotated_rgb = cv2.cvtColor(annotated_bgr, cv2.COLOR_BGR2RGB)
    return Image.fromarray(annotated_rgb)
# Gradio Interface — input order must match predict_and_annotate's parameters.
app = gr.Interface(
    fn=predict_and_annotate,
    inputs=[
        gr.Image(type="pil", label="Upload Image"),
        gr.Slider(minimum=0.0, maximum=1.0, value=0.25, step=0.01, label="Confidence Threshold"),
        gr.Slider(minimum=0.0, maximum=1.0, value=0.5, step=0.01, label="Keypoint Visibility Threshold"),
        gr.Radio(choices=list(MODEL_CONFIGS.keys()), label="Select Model", value="Dry Season Form"),
        gr.Checkbox(label="Show Keypoint Labels", value=True),
        gr.Slider(minimum=1, maximum=20, value=8, step=0.1, label="Keypoint Size"),
        gr.ColorPicker(label="Keypoint Color", value="#00FF00"),
        gr.Slider(minimum=0.3, maximum=3.0, value=1.0, step=0.1, label="Keypoint Label Font Size"),
    ],
    outputs=gr.Image(type="pil", label="Detection Result", format="png"),
    title="🦋 Melanitis leda Landmark Identification",
    description="Upload an image and select the model. Customize detection and keypoint display settings.",
    flagging_mode="never",
)

if __name__ == "__main__":
    app.launch()