"""Gradio demo that runs a YOLO keypoint model on an uploaded image.

The user picks one of the models listed in ``MODEL_CONFIGS``; the detections
(bounding boxes, class names, keypoints and optional keypoint labels) are
drawn onto a copy of the image with OpenCV and returned as a PIL image.
"""

import functools
import re
from pathlib import Path

import cv2
import gradio as gr
import numpy as np
from PIL import Image
from ultralytics import YOLO

path = Path(__file__).parent

# Keypoint names shared by the "Dry Season Form" and "All Season Form"
# models; index i names keypoint i of the model output.
_FULL_LABELS = (
    "F1", "F2", "F3", "F4", "F5", "F6", "F7", "F8",
    "F9", "F10", "F11", "F12", "F13", "F14", "F15", "F16",
    "H1", "H2", "H3", "H4", "H5", "H6", "H7", "H8", "H9",
    "H10", "H11", "H12", "H13", "H14", "H15", "H16", "H17",
    "fs1", "fs2", "fs3", "fs4", "fs5",
    "hs1", "hs2", "hs3", "hs4", "hs5", "hs6",
    "sc1", "sc2", "sex", "right", "left", "grey", "black", "white",
)

# Model configurations: weights file, keypoint names and inference size.
MODEL_CONFIGS = {
    "Dry Season Form": {
        "path": path / "models/DSF_Mleda_250.pt",
        "labels": list(_FULL_LABELS),
        "imgsz": 1280,
    },
    "Wet Season Form": {
        "path": path / "models/WSF_Mleda_200.pt",
        "labels": [
            "F12", "F14", "fs1", "fs2", "fs3", "fs4", "fs5",
            "H12", "H14", "hs1", "hs2", "hs3", "hs4", "hs5", "hs6",
            "white", "black", "grey", "sex", "blue", "green", "red",
            "sc1", "sc2",
        ],
        "imgsz": 1280,
    },
    "All Season Form": {
        "path": path / "models/DSF_WSF_Mleda_450.pt",
        "labels": list(_FULL_LABELS),
        "imgsz": 1280,
    },
}

# Fallback colour (green, in BGR order) used when a colour string is invalid.
_DEFAULT_BGR = (0, 255, 0)
_HEX_DIGITS_RE = re.compile(r"[0-9a-fA-F]{6}")
_RGB_FUNC_RE = re.compile(
    r"rgba?\(\s*([\d.]+)\s*,\s*([\d.]+)\s*,\s*([\d.]+)", re.IGNORECASE
)


def hex_to_bgr(hex_color: str) -> tuple:
    """Convert a colour string to a BGR tuple usable by OpenCV.

    Accepts ``#RRGGBB`` (leading ``#`` optional), the ``#RGB`` shorthand and
    CSS-style ``rgb(r, g, b)`` / ``rgba(r, g, b, a)`` strings (alpha is
    ignored). Anything that cannot be parsed yields green, ``(0, 255, 0)``.

    NOTE(review): the rgb()/rgba() branch exists because some Gradio
    releases reportedly hand such strings back from ``ColorPicker`` --
    confirm against the Gradio version in use.
    """
    if not isinstance(hex_color, str):
        return _DEFAULT_BGR
    text = hex_color.strip()

    rgb_match = _RGB_FUNC_RE.match(text)
    if rgb_match:
        try:
            r, g, b = (
                min(max(int(round(float(part))), 0), 255)
                for part in rgb_match.groups()
            )
        except ValueError:
            # e.g. "1.2.3" matches the pattern but is not a number.
            return _DEFAULT_BGR
        return (b, g, r)

    digits = text.lstrip("#")
    if len(digits) == 3:
        # Expand shorthand "abc" -> "aabbcc".
        digits = "".join(ch * 2 for ch in digits)
    if not _HEX_DIGITS_RE.fullmatch(digits):
        return _DEFAULT_BGR
    r, g, b = (int(digits[i:i + 2], 16) for i in (0, 2, 4))
    return (b, g, r)


@functools.lru_cache(maxsize=4)
def load_model(path):
    """Load the YOLO model stored at ``path``.

    Results are memoised per path so the weights are read from disk once
    rather than on every prediction request.

    NOTE(review): the cached model instance is shared between requests;
    confirm the app's concurrency settings make that safe.
    """
    return YOLO(path)


def draw_detections(
    image: np.ndarray,
    results,
    labels,
    keypoint_threshold: float,
    show_labels: bool,
    point_size: int,
    point_color: str,
    label_size: float,
) -> np.ndarray:
    """Draw boxes, class captions and keypoints onto a copy of ``image``.

    Args:
        image: BGR image array; it is copied, not modified.
        results: Iterable of result objects returned by calling the model.
        labels: Keypoint names; keypoint ``i`` is captioned ``labels[i]``.
            Keypoints with an index beyond ``len(labels)`` are not drawn.
        keypoint_threshold: Keypoints whose third value is not strictly
            greater than this are skipped.
        show_labels: Whether to write the keypoint name next to each point.
        point_size: Keypoint circle radius (truncated to an integer).
        point_color: Keypoint colour string, parsed by ``hex_to_bgr``.
        label_size: OpenCV font scale for the keypoint names.

    Returns:
        The annotated BGR image.
    """
    img = image.copy()
    color_bgr = hex_to_bgr(point_color)
    radius = int(point_size)
    box_color = (255, 255, 0)
    font = cv2.FONT_HERSHEY_SIMPLEX

    for result in results:
        if result.boxes is None:
            continue
        boxes = result.boxes.xywh.cpu().numpy()
        cls_ids = result.boxes.cls.int().cpu().numpy()
        confs = result.boxes.conf.cpu().numpy()
        if result.keypoints is not None:
            kpts_all = result.keypoints.data.cpu().numpy()
        else:
            # No keypoint output: still draw the boxes.
            kpts_all = [()] * len(boxes)

        for (x_c, y_c, w, h), cls_id, conf, kpts in zip(
            boxes, cls_ids, confs, kpts_all
        ):
            # Boxes come as centre/size; convert to corner coordinates.
            x1 = int(x_c - w / 2)
            y1 = int(y_c - h / 2)
            x2 = int(x_c + w / 2)
            y2 = int(y_c + h / 2)
            cv2.rectangle(img, (x1, y1), (x2, y2), color=box_color, thickness=2)

            class_name = result.names[int(cls_id)]
            text = f"{class_name} {conf:.2f}"
            (tw, th), _ = cv2.getTextSize(text, font, 0.6, 1)
            # Keep the caption on-canvas when the box touches the top edge.
            caption_bottom = max(y1, th + 4)
            cv2.rectangle(
                img,
                (x1, caption_bottom - th - 4),
                (x1 + tw, caption_bottom),
                box_color,
                cv2.FILLED,
            )
            cv2.putText(
                img, text, (x1, caption_bottom - 4), font, 0.6,
                (0, 0, 0), 1, cv2.LINE_AA,
            )

            for i, kpt in enumerate(kpts[:len(labels)]):
                # Rows are (x, y) or (x, y, v); apply the threshold only
                # when the third value is present.
                if len(kpt) > 2 and not kpt[2] > keypoint_threshold:
                    continue
                xi, yi = int(kpt[0]), int(kpt[1])
                cv2.circle(img, (xi, yi), radius, color_bgr, -1, cv2.LINE_AA)
                if show_labels:
                    cv2.putText(
                        img, labels[i], (xi + 3, yi + 3), font, label_size,
                        (255, 0, 0), 2, cv2.LINE_AA,
                    )
    return img


def predict_and_annotate(
    input_image: Image.Image,
    conf_threshold: float,
    keypoint_threshold: float,
    model_choice: str,
    show_labels: bool,
    point_size: int,
    point_color: str,
    label_size: float,
):
    """Run the selected model on ``input_image`` and return the annotated image.

    Args:
        input_image: Uploaded PIL image.
        conf_threshold: Confidence threshold passed to the model call.
        keypoint_threshold: Threshold forwarded to ``draw_detections``.
        model_choice: Key into ``MODEL_CONFIGS``.
        show_labels, point_size, point_color, label_size: Display options
            forwarded to ``draw_detections``.

    Returns:
        A PIL RGB image with the detections drawn on it.

    Raises:
        gr.Error: If no image was supplied or ``model_choice`` is unknown.
    """
    if input_image is None:
        raise gr.Error("Please upload an image before running detection.")
    config = MODEL_CONFIGS.get(model_choice)
    if config is None:
        raise gr.Error(f"Unknown model selection: {model_choice!r}")

    model = load_model(config["path"])
    labels = config["labels"]
    imgsz = config["imgsz"]

    # OpenCV drawing works in BGR; convert in and back out again.
    img_rgb = np.array(input_image.convert("RGB"))
    img_bgr = cv2.cvtColor(img_rgb, cv2.COLOR_RGB2BGR)

    results = model(img_bgr, conf=conf_threshold, imgsz=imgsz)

    annotated_bgr = draw_detections(
        img_bgr, results, labels,
        keypoint_threshold, show_labels,
        point_size, point_color, label_size,
    )
    annotated_rgb = cv2.cvtColor(annotated_bgr, cv2.COLOR_BGR2RGB)
    return Image.fromarray(annotated_rgb)


# Gradio Interface: input order must match predict_and_annotate's parameters.
app = gr.Interface(
    fn=predict_and_annotate,
    inputs=[
        gr.Image(type="pil", label="Upload Image"),
        gr.Slider(minimum=0.0, maximum=1.0, value=0.25, step=0.01,
                  label="Confidence Threshold"),
        gr.Slider(minimum=0.0, maximum=1.0, value=0.5, step=0.01,
                  label="Keypoint Visibility Threshold"),
        gr.Radio(choices=list(MODEL_CONFIGS.keys()), label="Select Model",
                 value="Dry Season Form"),
        gr.Checkbox(label="Show Keypoint Labels", value=True),
        # Whole-number steps: the radius is truncated to an int when drawn.
        gr.Slider(minimum=1, maximum=20, value=8, step=1,
                  label="Keypoint Size"),
        gr.ColorPicker(label="Keypoint Color", value="#00FF00"),
        gr.Slider(minimum=0.3, maximum=3.0, value=1.0, step=0.1,
                  label="Keypoint Label Font Size"),
    ],
    outputs=gr.Image(type="pil", label="Detection Result", format="png"),
    title="🦋 Melanitis leda Landmark Identification",
    description="Upload an image and select the model. Customize detection and keypoint display settings.",
    flagging_mode="never",
)

if __name__ == "__main__":
    app.launch()