"""Gradio app that batch-annotates Melanitis leda images with YOLO keypoint landmarks.

Each uploaded image is run through the selected YOLO model. The app returns the
annotated images, the paths of one XML annotation file per image, and a ZIP
archive bundling those XML files.
"""

import os
import re
import tempfile
import threading
import xml.etree.ElementTree as ET
import zipfile
from functools import lru_cache
from pathlib import Path
from typing import Optional

import cv2
import gradio as gr
import numpy as np
from PIL import Image
from ultralytics import YOLO

# Base path for model files
path = Path(__file__).parent

# Model configurations
MODEL_CONFIGS = {
    "Dry Season Form": {
        "path": path / "models/DSF_Mleda_250.pt",
        "labels": [
            "F1", "F2", "F3", "F4", "F5", "F6", "F7", "F8", "F9", "F10", "F11", "F12", "F13", "F14", "F15", "F16",
            "H1", "H2", "H3", "H4", "H5", "H6", "H7", "H8", "H9", "H10", "H11", "H12", "H13", "H14", "H15", "H16", "H17",
            "fs1", "fs2", "fs3", "fs4", "fs5", "hs1", "hs2", "hs3", "hs4", "hs5", "hs6",
            "sc1", "sc2", "sex", "right", "left", "grey", "black", "white",
        ],
        "imgsz": 1280,
    },
    "Wet Season Form": {
        "path": path / "models/WSF_Mleda_200.pt",
        "labels": [
            'F12', 'F14', 'fs1', 'fs2', 'fs3', 'fs4', 'fs5', 'H12', 'H14',
            'hs1', 'hs2', 'hs3', 'hs4', 'hs5', 'hs6', 'white', 'black', 'grey',
            'sex', 'blue', 'green', 'red', 'sc1', 'sc2',
        ],
        "imgsz": 1280,
    },
    "All Season Form": {
        "path": path / "models/DSF_WSF_Mleda_450.pt",
        "labels": [
            "F1", "F2", "F3", "F4", "F5", "F6", "F7", "F8", "F9", "F10", "F11", "F12", "F13", "F14", "F15", "F16",
            "H1", "H2", "H3", "H4", "H5", "H6", "H7", "H8", "H9", "H10", "H11", "H12", "H13", "H14", "H15", "H16", "H17",
            "fs1", "fs2", "fs3", "fs4", "fs5", "hs1", "hs2", "hs3", "hs4", "hs5", "hs6",
            "sc1", "sc2", "sex", "right", "left", "grey", "black", "white",
        ],
        "imgsz": 1280,
    },
}

# Directory for XML annotations (each request writes into its own subdirectory)
ANNOTATIONS_DIR = Path(tempfile.gettempdir()) / "annotations"
ANNOTATIONS_DIR.mkdir(parents=True, exist_ok=True)

# Fallback keypoint colour (green, BGR order) for unparseable colour strings.
_DEFAULT_POINT_BGR = (0, 255, 0)
# Bounding-box and caption background colour (BGR).
_BOX_COLOR = (255, 255, 0)
_HEX_DIGITS_RE = re.compile(r"[0-9a-fA-F]{6}|[0-9a-fA-F]{3}")
_RGB_FUNC_RE = re.compile(
    r"rgba?\(\s*([0-9.]+)\s*,\s*([0-9.]+)\s*,\s*([0-9.]+)", re.IGNORECASE
)
# Cached YOLO models are shared between requests. Ultralytics advises against
# running one model instance from several threads at once, so inference is
# serialised with this lock.
_INFERENCE_LOCK = threading.Lock()


def hex_to_bgr(hex_color: str) -> tuple:
    """Convert a colour string to a BGR tuple for OpenCV.

    Accepts ``#RRGGBB`` / ``RRGGBB`` and the shorthand ``#RGB``. It also accepts
    ``rgb(r, g, b)`` / ``rgba(r, g, b, a)`` strings, ignoring alpha; some Gradio
    ColorPicker versions emit these. Anything unparseable falls back to green
    ``(0, 255, 0)`` rather than raising. That includes ``None`` and strings with
    non-hex digits.
    """
    if not isinstance(hex_color, str):
        return _DEFAULT_POINT_BGR
    text = hex_color.strip()

    rgb = _RGB_FUNC_RE.match(text)
    if rgb:
        try:
            r, g, b = (min(255, max(0, round(float(c)))) for c in rgb.groups())
        except ValueError:  # e.g. "1.2.3" matched the character class
            return _DEFAULT_POINT_BGR
        return (b, g, r)

    digits = text.lstrip("#")
    if not _HEX_DIGITS_RE.fullmatch(digits):
        return _DEFAULT_POINT_BGR  # Default to green if invalid
    if len(digits) == 3:
        digits = "".join(ch * 2 for ch in digits)
    r = int(digits[0:2], 16)
    g = int(digits[2:4], 16)
    b = int(digits[4:6], 16)
    return (b, g, r)


@lru_cache(maxsize=len(MODEL_CONFIGS))
def load_model(path: Path):
    """Load the YOLO model at *path*, caching it so weights are read from disk once."""
    return YOLO(str(path))


def _keypoint_array(result) -> np.ndarray:
    """Return *result*'s keypoints as a (instances, K, 3) array of x, y, visibility.

    Returns an empty array when the result carries no keypoints. If the keypoints
    have only two values (no visibility column), a visibility of 1.0 is filled in.
    """
    keypoints = getattr(result, "keypoints", None)
    if keypoints is None:
        return np.empty((0, 0, 3), dtype=np.float32)
    data = keypoints.data.cpu().numpy()
    if data.shape[-1] == 2:
        ones = np.ones(data.shape[:-1] + (1,), dtype=data.dtype)
        data = np.concatenate([data, ones], axis=-1)
    return data


def _visible_keypoints(kpts, threshold: float):
    """Yield ``(index, x, y)`` for each keypoint whose visibility exceeds *threshold*."""
    for i, (x, y, v) in enumerate(kpts):
        if v > threshold:
            yield i, float(x), float(y)


def _keypoint_label(labels, index: int) -> str:
    """Return the configured name of keypoint *index*, or ``kp<index>`` if unnamed."""
    return labels[index] if index < len(labels) else f"kp{index}"


def _draw_box(img: np.ndarray, box, text: str) -> None:
    """Draw one centre-based (x_c, y_c, w, h) box and its filled caption on *img* in place."""
    x_c, y_c, w, h = box
    x1, y1 = int(x_c - w / 2), int(y_c - h / 2)
    x2, y2 = int(x_c + w / 2), int(y_c + h / 2)
    cv2.rectangle(img, (x1, y1), (x2, y2), _BOX_COLOR, 2)

    (tw, th), _ = cv2.getTextSize(text, cv2.FONT_HERSHEY_SIMPLEX, 0.6, 1)
    # Caption sits above the box; if that would leave the image (box touching
    # the top edge), place it just inside the box instead so it stays visible.
    top = y1 - th - 4
    if top < 0:
        top = max(y1, 0)
    cv2.rectangle(img, (x1, top), (x1 + tw, top + th + 4), _BOX_COLOR, cv2.FILLED)
    cv2.putText(img, text, (x1, top + th), cv2.FONT_HERSHEY_SIMPLEX, 0.6, (0, 0, 0), 1)


def draw_detections(image: np.ndarray, results, labels, keypoint_threshold: float,
                    show_labels: bool, point_size: int, point_color: str,
                    label_size: float) -> np.ndarray:
    """Draw bounding boxes, keypoints, and labels on a copy of *image*.

    Args:
        image: BGR image array; it is not modified.
        results: iterable of Ultralytics results for this image.
        labels: keypoint names indexed by keypoint position. Unnamed positions are
            drawn as ``kp<i>``, matching the XML output.
        keypoint_threshold: keypoints with visibility <= this value are skipped.
        show_labels: whether to write each keypoint's name next to it.
        point_size: keypoint circle radius in pixels.
        point_color: keypoint colour string, parsed with ``hex_to_bgr``.
        label_size: OpenCV font scale for keypoint names.

    Returns:
        The annotated BGR image.
    """
    img = image.copy()
    color_bgr = hex_to_bgr(point_color)
    radius = int(point_size)
    for result in results:
        if result.boxes is None:
            continue
        boxes = result.boxes.xywh.cpu().numpy()
        cls_ids = result.boxes.cls.int().cpu().numpy()
        confs = result.boxes.conf.cpu().numpy()
        kpts_all = _keypoint_array(result)

        for idx, (box, cls_id, conf) in enumerate(zip(boxes, cls_ids, confs)):
            _draw_box(img, box, f"{result.names[int(cls_id)]} {conf:.2f}")
            if idx >= len(kpts_all):
                continue  # no keypoints for this detection
            for i, x, y in _visible_keypoints(kpts_all[idx], keypoint_threshold):
                xi, yi = int(x), int(y)
                cv2.circle(img, (xi, yi), radius, color_bgr, -1)
                if show_labels:
                    cv2.putText(img, _keypoint_label(labels, i), (xi + 3, yi + 3),
                                cv2.FONT_HERSHEY_SIMPLEX, label_size, (255, 0, 0), 2)
    return img


def generate_xml(filename: str, width: int, height: int,
                 keypoints: list[tuple[float, float, str]],
                 output_dir: Optional[Path] = None,
                 xml_filename: Optional[str] = None):
    """Write an XML annotation file for a single image and return its path.

    Args:
        filename: image file name, stored in the ``<image filename=...>`` attribute.
        width: image width in pixels.
        height: image height in pixels.
        keypoints: ``(x, y, label)`` triples. Each becomes one ``<point>`` element,
            with ids numbered in order.
        output_dir: directory to write into; defaults to ``ANNOTATIONS_DIR``.
        xml_filename: name of the XML file; defaults to ``"<filename>.xml"``.

    Returns:
        Path of the written XML file.
    """
    annotations = ET.Element("annotations")
    image_tag = ET.SubElement(annotations, "image", filename=filename,
                              width=str(width), height=str(height))
    for idx, (x, y, label) in enumerate(keypoints):
        ET.SubElement(image_tag, "point", id=str(idx), x=str(x), y=str(y), label=label)

    target_dir = ANNOTATIONS_DIR if output_dir is None else Path(output_dir)
    xml_path = target_dir / (xml_filename or f"{filename}.xml")
    ET.ElementTree(annotations).write(str(xml_path), encoding="utf-8", xml_declaration=True)
    return xml_path


def _resolve_upload(file_obj) -> tuple[Path, str]:
    """Return ``(path on disk, original file name)`` for one uploaded file.

    Accepts the value shapes Gradio's ``gr.File`` has produced:
    - dicts with ``name`` or ``path`` and an optional ``orig_name``;
    - plain path strings or ``os.PathLike`` objects (Gradio 4+ ``type="filepath"``);
    - tempfile-like objects exposing ``.name`` and optionally ``.orig_name``.
    """
    if isinstance(file_obj, dict):
        tmp_path = Path(file_obj.get("name") or file_obj["path"])
        orig = file_obj.get("orig_name")
    elif isinstance(file_obj, (str, os.PathLike)):
        tmp_path = Path(file_obj)
        orig = None
    else:
        tmp_path = Path(file_obj.name)
        orig = getattr(file_obj, "orig_name", None)
    # .name drops any directory components a client may have sent.
    orig_name = Path(orig).name if orig else ""
    return tmp_path, orig_name or tmp_path.name


def _unique_xml_name(image_name: str, taken: set) -> str:
    """Return an XML file name for *image_name* not yet in *taken*, and reserve it.

    The first use gives ``"<image_name>.xml"``. Repeats give
    ``"<stem>_<n><suffix>.xml"``. Comparison is case-insensitive so names stay
    distinct on case-insensitive file systems.
    """
    candidate = f"{image_name}.xml"
    stem, suffix = Path(image_name).stem, Path(image_name).suffix
    n = 1
    while candidate.casefold() in taken:
        candidate = f"{stem}_{n}{suffix}.xml"
        n += 1
    taken.add(candidate.casefold())
    return candidate


def process_images(image_list, conf_threshold: float, keypoint_threshold: float, model_choice: str,
                   show_labels: bool, point_size: int, point_color: str, label_size: float):
    """Annotate uploaded images, write one XML per image, and bundle the XMLs into a ZIP.

    Args:
        image_list: uploaded files from ``gr.File(file_count="multiple")``; may be
            ``None``/empty when nothing was uploaded.
        conf_threshold: detection confidence threshold passed to the model.
        keypoint_threshold: keypoints with visibility <= this value are ignored,
            both when drawing and in the XML.
        model_choice: key into ``MODEL_CONFIGS``.
        show_labels: draw keypoint names on the images.
        point_size: keypoint circle radius.
        point_color: keypoint colour string.
        label_size: font scale for keypoint names.

    Returns:
        A tuple of three values:
        - the annotated PIL images;
        - the newline-joined paths of the XML files;
        - the ZIP path as a string, or ``None`` when no images were given.
    """
    if not image_list:
        return [], "", None

    model_cfg = MODEL_CONFIGS[model_choice]
    model = load_model(model_cfg["path"])
    labels = model_cfg["labels"]
    imgsz = model_cfg["imgsz"]

    # A fresh directory per request stops concurrent sessions from overwriting
    # each other's XML files or ZIP archive.
    run_dir = Path(tempfile.mkdtemp(prefix="run_", dir=ANNOTATIONS_DIR))

    output_images = []
    xml_paths = []
    taken_names: set = set()
    for file_obj in image_list:
        tmp_path, orig_name = _resolve_upload(file_obj)

        with Image.open(tmp_path) as img:
            img_rgb = np.array(img.convert("RGB"))
        img_bgr = cv2.cvtColor(img_rgb, cv2.COLOR_RGB2BGR)
        height, width = img_bgr.shape[:2]

        with _INFERENCE_LOCK:
            results = model(img_bgr, conf=conf_threshold, imgsz=imgsz)

        annotated = draw_detections(img_bgr, results, labels, keypoint_threshold,
                                    show_labels, point_size, point_color, label_size)
        output_images.append(Image.fromarray(cv2.cvtColor(annotated, cv2.COLOR_BGR2RGB)))

        # Flatten visible keypoints of every detection into (x, y, label) triples.
        keypoints = [
            (x, y, _keypoint_label(labels, i))
            for res in results
            for kpts in _keypoint_array(res)
            for i, x, y in _visible_keypoints(kpts, keypoint_threshold)
        ]
        xml_paths.append(
            generate_xml(orig_name, width, height, keypoints,
                         output_dir=run_dir,
                         xml_filename=_unique_xml_name(orig_name, taken_names))
        )

    # Create ZIP of all XMLs (names are unique, so no duplicate archive entries)
    zip_path = run_dir / "xml_annotations.zip"
    with zipfile.ZipFile(zip_path, "w", compression=zipfile.ZIP_DEFLATED) as zipf:
        for p in xml_paths:
            zipf.write(str(p), f"annotations/{p.name}")

    xml_list_str = "\n".join(str(p) for p in xml_paths)
    return output_images, xml_list_str, str(zip_path)


# Gradio Interface
def main():
    """Build the batch-annotation Gradio interface and launch it."""
    iface = gr.Interface(
        fn=process_images,
        inputs=[
            gr.File(file_types=["image"], file_count="multiple", label="Upload Images"),
            gr.Slider(minimum=0.0, maximum=1.0, value=0.25, step=0.01, label="Confidence Threshold"),
            gr.Slider(minimum=0.0, maximum=1.0, value=0.5, step=0.01, label="Keypoint Visibility Threshold"),
            gr.Radio(choices=list(MODEL_CONFIGS.keys()), label="Select Model", value="Dry Season Form"),
            gr.Checkbox(label="Show Keypoint Labels", value=True),
            gr.Slider(minimum=1, maximum=20, value=8, step=1, label="Keypoint Size"),
            gr.ColorPicker(label="Keypoint Color", value="#00FF00"),
            gr.Slider(minimum=0.3, maximum=3.0, value=1.0, step=0.1, label="Keypoint Label Font Size"),
        ],
        outputs=[
            gr.Gallery(label="Detection Results"),
            gr.Textbox(label="Generated XML Paths"),
            gr.File(label="Download All XMLs as ZIP"),
        ],
        title="🦋 Melanitis leda Landmark Batch Annotator",
        description="Upload multiple images. It annotates each with keypoints and packages XMLs in a ZIP archive.",
        # NOTE(review): newer Gradio releases deprecate allow_flagging in favour of
        # flagging_mode — confirm against the installed Gradio version.
        allow_flagging="never",
    )
    iface.launch()


if __name__ == "__main__":
    main()