Spaces:
Sleeping
Sleeping
import functools
from pathlib import Path

import cv2
import gradio as gr
import numpy as np
from PIL import Image
from ultralytics import YOLO
# Directory containing this script; model weight paths are resolved against it.
path = Path(__file__).parent

# Keypoint names in model output order: keypoint i of a prediction is labelled
# with entry i. The dry-season and all-season models use this same list.
_DSF_KEYPOINT_LABELS = (
    [f"F{n}" for n in range(1, 17)]
    + [f"H{n}" for n in range(1, 18)]
    + [f"fs{n}" for n in range(1, 6)]
    + [f"hs{n}" for n in range(1, 7)]
    + ["sc1", "sc2"]
    + ["sex", "right", "left", "grey", "black", "white"]
)

# Keypoint names for the wet-season model, in model output order.
_WSF_KEYPOINT_LABELS = [
    "F12", "F14", "fs1", "fs2", "fs3", "fs4", "fs5",
    "H12", "H14", "hs1", "hs2", "hs3", "hs4", "hs5", "hs6",
    "white", "black", "grey", "sex", "blue", "green", "red", "sc1", "sc2",
]

# Model configurations: UI display name -> weights path, keypoint labels and
# inference image size. Each entry gets its own copy of the label list.
MODEL_CONFIGS = {
    "Dry Season Form": {
        "path": path / "models/DSF_Mleda_250.pt",
        "labels": list(_DSF_KEYPOINT_LABELS),
        "imgsz": 1280,
    },
    "Wet Season Form": {
        "path": path / "models/WSF_Mleda_200.pt",
        "labels": list(_WSF_KEYPOINT_LABELS),
        "imgsz": 1280,
    },
    "All Season Form": {
        "path": path / "models/DSF_WSF_Mleda_450.pt",
        "labels": list(_DSF_KEYPOINT_LABELS),
        "imgsz": 1280,
    },
}
def hex_to_bgr(hex_color: str) -> tuple:
    """Convert a color string to an OpenCV-style ``(B, G, R)`` tuple.

    Accepted forms (surrounding whitespace is ignored):

    * ``#RRGGBB`` or ``RRGGBB`` -- six hex digits, any case;
    * ``#RGB`` shorthand, expanded to ``#RRGGBB``;
    * CSS ``rgb(r, g, b)`` / ``rgba(r, g, b, a)``. Channels are rounded and
      clamped to 0..255, and alpha is ignored. This is accepted defensively
      in case the color widget emits functional notation.

    Anything else (wrong length, non-hex digits, signs, ``None``) returns
    green ``(0, 255, 0)`` rather than raising.
    """
    fallback = (0, 255, 0)
    if not isinstance(hex_color, str):
        return fallback
    color = hex_color.strip()

    lowered = color.lower()
    if lowered.startswith(("rgb(", "rgba(")) and lowered.endswith(")"):
        parts = lowered[lowered.index("(") + 1:-1].split(",")
        if len(parts) not in (3, 4):
            return fallback
        try:
            r, g, b = (min(255, max(0, int(round(float(p))))) for p in parts[:3])
        except (ValueError, OverflowError):  # non-numeric, NaN or infinite channel
            return fallback
        return (b, g, r)

    digits = color.lstrip("#")
    if len(digits) == 3:
        digits = "".join(ch * 2 for ch in digits)
    # Validate explicitly: int(s, 16) alone would also accept a sign ("-1"),
    # embedded whitespace or underscores and produce out-of-range channels.
    if len(digits) != 6 or any(ch not in "0123456789abcdefABCDEF" for ch in digits):
        return fallback
    r, g, b = (int(digits[k:k + 2], 16) for k in (0, 2, 4))
    return (b, g, r)
@functools.lru_cache(maxsize=None)
def load_model(path):
    """Return the YOLO model stored at *path*, loading it from disk only once.

    Results are memoised per ``path`` value. Repeated predictions with the
    same model therefore reuse the already-loaded weights instead of
    re-reading the file on every request. The cache is unbounded; within this
    file, callers only pass the fixed set of paths in ``MODEL_CONFIGS``.
    ``"x.pt"`` and ``Path("x.pt")`` are cached as separate entries.
    ``load_model.cache_clear()`` drops all cached models.

    NOTE(review): the cached instance is shared between calls. Ultralytics
    recommends one model instance per thread for concurrent inference. This
    assumes Gradio's default per-event concurrency limit (1) keeps predictions
    serialized. Confirm this if the app's concurrency settings are raised.
    """
    return YOLO(path)
def draw_detections(image: np.ndarray, results, labels, keypoint_threshold: float, show_labels: bool, point_size: int, point_color: str, label_size: float) -> np.ndarray:
    """Draw boxes, class/confidence tags and keypoints onto a copy of *image*.

    Args:
        image: Image array to annotate. It is copied and never modified.
            Drawing colors below are written in BGR order, matching the
            caller, which passes a BGR image.
        results: Ultralytics results to draw. Each result must expose
            ``boxes`` (``xywh``, ``cls``, ``conf``) and ``names``.
            ``keypoints`` may be ``None``, in which case only boxes are drawn.
        labels: Keypoint names; ``labels[i]`` names keypoint ``i``. Keypoints
            without a matching name are not drawn.
        keypoint_threshold: A keypoint is drawn only when its score is
            strictly greater than this value. Keypoints that carry no score
            column (``(x, y)`` only) are always drawn.
        show_labels: Whether to write each keypoint's name beside it.
        point_size: Keypoint circle radius in pixels. It is rounded to an
            integer and clamped to at least 1.
        point_color: Keypoint color string, converted with ``hex_to_bgr``.
        label_size: OpenCV font scale for keypoint names.

    Returns:
        The annotated copy of *image*.
    """
    img = image.copy()
    font = cv2.FONT_HERSHEY_SIMPLEX
    box_color = (255, 255, 0)
    tag_scale = 0.6
    point_bgr = hex_to_bgr(point_color)
    radius = max(1, int(round(point_size)))

    for result in results:
        boxes = result.boxes.xywh.cpu().numpy()
        cls_ids = result.boxes.cls.int().cpu().numpy()
        confs = result.boxes.conf.cpu().numpy()
        if result.keypoints is not None:
            kpts_all = result.keypoints.data.cpu().numpy()
        else:
            # No keypoint output: pair every box with an empty keypoint array
            # so the boxes are still drawn.
            kpts_all = np.empty((len(boxes), 0, 3))

        for (x_c, y_c, w, h), cls_id, conf, kpts in zip(boxes, cls_ids, confs, kpts_all):
            x1, y1 = int(x_c - w / 2), int(y_c - h / 2)
            x2, y2 = int(x_c + w / 2), int(y_c + h / 2)
            cv2.rectangle(img, (x1, y1), (x2, y2), color=box_color, thickness=2)

            text = f"{result.names[int(cls_id)]} {conf:.2f}"
            (tw, th), _ = cv2.getTextSize(text, font, tag_scale, 1)
            # The tag normally sits directly above the box. When the box
            # touches the top edge it would land above row 0 and be clipped,
            # so draw it just inside the box's top edge instead.
            tag_top = y1 - th - 4
            if tag_top < 0:
                tag_top = max(y1, 0)
            cv2.rectangle(img, (x1, tag_top), (x1 + tw, tag_top + th + 4), box_color, cv2.FILLED)
            cv2.putText(img, text, (x1, tag_top + th), font, tag_scale, (0, 0, 0), 1, cv2.LINE_AA)

            # Only keypoints that have a name in `labels` are considered.
            for i, kp in enumerate(kpts[: len(labels)]):
                score = kp[2] if len(kp) > 2 else np.inf
                if score > keypoint_threshold:
                    xi, yi = int(kp[0]), int(kp[1])
                    cv2.circle(img, (xi, yi), radius, point_bgr, -1, cv2.LINE_AA)
                    if show_labels:
                        cv2.putText(img, labels[i], (xi + 3, yi + 3), font, label_size, (255, 0, 0), 2, cv2.LINE_AA)
    return img
def predict_and_annotate(input_image: Image.Image, conf_threshold: float, keypoint_threshold: float, model_choice: str, show_labels: bool, point_size: int, point_color: str, label_size: float):
    """Run the selected landmark model on *input_image* and return the annotated image.

    Parameters are received in the same order as the Gradio inputs.

    Args:
        input_image: Uploaded image, or ``None`` if nothing was uploaded.
        conf_threshold: Detection confidence threshold passed to the model.
        keypoint_threshold: Minimum keypoint score for a point to be drawn.
        model_choice: Key into ``MODEL_CONFIGS``. An unknown key raises
            ``KeyError``.
        show_labels: Whether keypoint names are drawn.
        point_size: Keypoint circle radius in pixels.
        point_color: Keypoint color string (e.g. ``#00FF00``).
        label_size: Font scale for keypoint names.

    Returns:
        A PIL image (RGB) with detections and keypoints drawn on it.

    Raises:
        gr.Error: If no image was uploaded. Gradio shows the message to the
            user instead of an opaque traceback.
    """
    if input_image is None:
        raise gr.Error("Please upload an image before running detection.")

    config = MODEL_CONFIGS[model_choice]
    model = load_model(config["path"])

    img_rgb = np.array(input_image.convert("RGB"))
    # Ultralytics treats numpy inputs as BGR, and the OpenCV drawing uses BGR
    # colors, so work in BGR and convert back at the end.
    img_bgr = cv2.cvtColor(img_rgb, cv2.COLOR_RGB2BGR)
    results = model(img_bgr, conf=conf_threshold, imgsz=config["imgsz"])

    annotated_bgr = draw_detections(
        img_bgr, results, config["labels"], keypoint_threshold,
        show_labels, point_size, point_color, label_size,
    )
    return Image.fromarray(cv2.cvtColor(annotated_bgr, cv2.COLOR_BGR2RGB))
# Gradio Interface. The inputs are listed in the same order as the parameters
# of predict_and_annotate.
app = gr.Interface(
    fn=predict_and_annotate,
    inputs=[
        gr.Image(type="pil", label="Upload Image"),
        gr.Slider(minimum=0.0, maximum=1.0, value=0.25, step=0.01, label="Confidence Threshold"),
        gr.Slider(minimum=0.0, maximum=1.0, value=0.5, step=0.01, label="Keypoint Visibility Threshold"),
        gr.Radio(choices=list(MODEL_CONFIGS), label="Select Model", value="Dry Season Form"),
        gr.Checkbox(label="Show Keypoint Labels", value=True),
        # The size becomes a whole-pixel circle radius when drawn, so
        # fractional steps would produce no visible change.
        gr.Slider(minimum=1, maximum=20, value=8, step=1, label="Keypoint Size"),
        gr.ColorPicker(label="Keypoint Color", value="#00FF00"),
        gr.Slider(minimum=0.3, maximum=3.0, value=1.0, step=0.1, label="Keypoint Label Font Size"),
    ],
    outputs=gr.Image(type="pil", label="Detection Result", format="png"),
    title="🦋 Melanitis leda Landmark Identification",
    description="Upload an image and select the model. Customize detection and keypoint display settings.",
    flagging_mode="never",
)

if __name__ == "__main__":
    # Start the web server only when run as a script, not when imported.
    app.launch()