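"""Gradio app for batch landmark annotation of Melanitis leda butterfly images.

Runs an Ultralytics YOLO pose model over uploaded images, draws the detected
keypoints, and packages per-image XML annotation files into a downloadable ZIP.
"""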
import tempfile
import zipfile
import xml.etree.ElementTree as ET
from pathlib import Path

import cv2
import gradio as gr
import numpy as np
from PIL import Image
from ultralytics import YOLO

BASE_DIR = Path(__file__).parent

# Keypoint label order must match the order the corresponding model was trained
# with. The dry-season and combined models share the same label set.
FULL_SEASON_LABELS = [
    "F1", "F2", "F3", "F4", "F5", "F6", "F7", "F8", "F9", "F10", "F11", "F12",
    "F13", "F14", "F15", "F16",
    "H1", "H2", "H3", "H4", "H5", "H6", "H7", "H8", "H9", "H10", "H11", "H12",
    "H13", "H14", "H15", "H16", "H17",
    "fs1", "fs2", "fs3", "fs4", "fs5",
    "hs1", "hs2", "hs3", "hs4", "hs5", "hs6",
    "sc1", "sc2",
    "sex", "right", "left", "grey", "black", "white",
]

MODEL_CONFIGS = {
    "Dry Season Form": {
        "path": BASE_DIR / "models/DSF_Mleda_250.pt",
        "labels": FULL_SEASON_LABELS,
        "imgsz": 1280,
    },
    "Wet Season Form": {
        "path": BASE_DIR / "models/WSF_Mleda_200.pt",
        "labels": [
            "F12", "F14", "fs1", "fs2", "fs3", "fs4", "fs5",
            "H12", "H14", "hs1", "hs2", "hs3", "hs4", "hs5", "hs6",
            "white", "black", "grey", "sex", "blue", "green", "red", "sc1", "sc2",
        ],
        "imgsz": 1280,
    },
    "All Season Form": {
        "path": BASE_DIR / "models/DSF_WSF_Mleda_450.pt",
        "labels": FULL_SEASON_LABELS,
        "imgsz": 1280,
    },
}
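
# NOTE: the .pt weight files referenced above are assumed to exist under a
# models/ directory next to this script.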

# Per-image XML files are written here before being zipped for download.
ANNOTATIONS_DIR = Path(tempfile.gettempdir()) / "annotations"
ANNOTATIONS_DIR.mkdir(parents=True, exist_ok=True)


def hex_to_bgr(hex_color: str) -> tuple[int, int, int]:
    """Convert a #RRGGBB hex color to an OpenCV BGR tuple."""
    hex_color = hex_color.lstrip("#")
    if len(hex_color) != 6:
        # Fall back to green if the color string is malformed.
        return (0, 255, 0)
    r = int(hex_color[0:2], 16)
    g = int(hex_color[2:4], 16)
    b = int(hex_color[4:6], 16)
    return (b, g, r)
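
# Example: hex_to_bgr("#FF0000") returns (0, 0, 255), since OpenCV uses BGR channel order.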


def load_model(path: Path) -> YOLO:
    """Load a YOLO model from the given weights path."""
    return YOLO(str(path))


def draw_detections(image: np.ndarray, results, labels, keypoint_threshold: float,
                    show_labels: bool, point_size: int, point_color: str,
                    label_size: float) -> np.ndarray:
    """Draw bounding boxes, keypoints, and labels on a BGR image."""
    img = image.copy()
    color_bgr = hex_to_bgr(point_color)

    for result in results:
        if result.keypoints is None:
            continue
        boxes = result.boxes.xywh.cpu().numpy()
        cls_ids = result.boxes.cls.int().cpu().numpy()
        confs = result.boxes.conf.cpu().numpy()
        kpts_all = result.keypoints.data.cpu().numpy()

        for (x_c, y_c, w, h), cls_id, conf, kpts in zip(boxes, cls_ids, confs, kpts_all):
            x1, y1 = int(x_c - w / 2), int(y_c - h / 2)
            x2, y2 = int(x_c + w / 2), int(y_c + h / 2)

            # Bounding box with a class/confidence caption above it.
            cv2.rectangle(img, (x1, y1), (x2, y2), (255, 255, 0), 2)
            text = f"{result.names[int(cls_id)]} {conf:.2f}"
            (tw, th), _ = cv2.getTextSize(text, cv2.FONT_HERSHEY_SIMPLEX, 0.6, 1)
            cv2.rectangle(img, (x1, y1 - th - 4), (x1 + tw, y1), (255, 255, 0), cv2.FILLED)
            cv2.putText(img, text, (x1, y1 - 4), cv2.FONT_HERSHEY_SIMPLEX, 0.6, (0, 0, 0), 1)

            # Keypoints above the visibility threshold, with optional labels.
            for i, (x, y, v) in enumerate(kpts):
                if v > keypoint_threshold and i < len(labels):
                    xi, yi = int(x), int(y)
                    cv2.circle(img, (xi, yi), int(point_size), color_bgr, -1)
                    if show_labels:
                        cv2.putText(img, labels[i], (xi + 3, yi + 3), cv2.FONT_HERSHEY_SIMPLEX,
                                    label_size, (255, 0, 0), 2)
    return img
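
# Standalone usage sketch (file names hypothetical):
#
#   model = load_model(MODEL_CONFIGS["Dry Season Form"]["path"])
#   img = cv2.imread("sample.jpg")  # OpenCV loads images as BGR
#   results = model(img, conf=0.25, imgsz=1280)
#   annotated = draw_detections(img, results, MODEL_CONFIGS["Dry Season Form"]["labels"],
#                               keypoint_threshold=0.5, show_labels=True, point_size=8,
#                               point_color="#00FF00", label_size=1.0)
#   cv2.imwrite("annotated.jpg", annotated)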


def generate_xml(filename: str, width: int, height: int,
                 keypoints: list[tuple[float, float, str]]) -> Path:
    """Write an XML annotation file for a single image and return its path."""
    annotations = ET.Element("annotations")
    image_tag = ET.SubElement(annotations, "image", filename=filename,
                              width=str(width), height=str(height))
    for idx, (x, y, label) in enumerate(keypoints):
        ET.SubElement(image_tag, "point", id=str(idx), x=str(x), y=str(y), label=label)
    tree = ET.ElementTree(annotations)
    xml_path = ANNOTATIONS_DIR / f"{filename}.xml"
    tree.write(str(xml_path), encoding="utf-8", xml_declaration=True)
    return xml_path
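
# A file produced by generate_xml("IMG_0001.jpg", 4000, 3000, [(512.0, 640.5, "F1")])
# would contain (ElementTree writes it unindented; line breaks added here for readability):
#
#   <?xml version='1.0' encoding='utf-8'?>
#   <annotations>
#     <image filename="IMG_0001.jpg" width="4000" height="3000">
#       <point id="0" x="512.0" y="640.5" label="F1" />
#     </image>
#   </annotations>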


def process_images(image_list, conf_threshold: float, keypoint_threshold: float,
                   model_choice: str, show_labels: bool, point_size: int,
                   point_color: str, label_size: float):
    """Annotate each uploaded image, write per-image XMLs, and package them into a ZIP."""
    model_cfg = MODEL_CONFIGS[model_choice]
    model = load_model(model_cfg["path"])
    labels = model_cfg["labels"]
    imgsz = model_cfg["imgsz"]

    output_images = []
    xml_paths = []

    for file_obj in image_list:
        # Depending on version, Gradio passes uploaded files as dicts,
        # tempfile-like objects, or plain paths; handle all three.
        if isinstance(file_obj, dict):
            tmp_path = Path(file_obj["name"])
            orig_name = Path(file_obj["orig_name"]).name
        elif isinstance(file_obj, (str, Path)):
            tmp_path = Path(file_obj)
            orig_name = tmp_path.name
        else:
            tmp_path = Path(file_obj.name)
            orig_name = tmp_path.name

        img = Image.open(tmp_path)
        img_rgb = np.array(img.convert("RGB"))
        img_bgr = cv2.cvtColor(img_rgb, cv2.COLOR_RGB2BGR)
        height, width = img_bgr.shape[:2]

        results = model(img_bgr, conf=conf_threshold, imgsz=imgsz)
        annotated = draw_detections(img_bgr, results, labels,
                                    keypoint_threshold, show_labels,
                                    point_size, point_color, label_size)
        output_images.append(Image.fromarray(cv2.cvtColor(annotated, cv2.COLOR_BGR2RGB)))

        # Collect visible keypoints for the XML annotation.
        keypoints = []
        for res in results:
            if res.keypoints is None:
                continue
            for kpts in res.keypoints.data.cpu().numpy():
                for i, (x, y, v) in enumerate(kpts):
                    if v > keypoint_threshold:
                        keypoints.append((float(x), float(y),
                                          labels[i] if i < len(labels) else f"kp{i}"))

        xml_paths.append(generate_xml(orig_name, width, height, keypoints))

    # Bundle all XMLs under an annotations/ folder inside the ZIP.
    zip_path = Path(tempfile.gettempdir()) / "xml_annotations.zip"
    with zipfile.ZipFile(zip_path, "w") as zipf:
        for p in xml_paths:
            zipf.write(str(p), (Path("annotations") / p.name).as_posix())

    xml_list_str = "\n".join(str(p) for p in xml_paths)
    return output_images, xml_list_str, str(zip_path)


def main():
    iface = gr.Interface(
        fn=process_images,
        inputs=[
            gr.File(file_types=["image"], file_count="multiple", label="Upload Images"),
            gr.Slider(minimum=0.0, maximum=1.0, value=0.25, step=0.01, label="Confidence Threshold"),
            gr.Slider(minimum=0.0, maximum=1.0, value=0.5, step=0.01, label="Keypoint Visibility Threshold"),
            gr.Radio(choices=list(MODEL_CONFIGS.keys()), label="Select Model", value="Dry Season Form"),
            gr.Checkbox(label="Show Keypoint Labels", value=True),
            gr.Slider(minimum=1, maximum=20, value=8, step=1, label="Keypoint Size"),
            gr.ColorPicker(label="Keypoint Color", value="#00FF00"),
            gr.Slider(minimum=0.3, maximum=3.0, value=1.0, step=0.1, label="Keypoint Label Font Size"),
        ],
        outputs=[
            gr.Gallery(label="Detection Results"),
            gr.Textbox(label="Generated XML Paths"),
            gr.File(label="Download All XMLs as ZIP"),
        ],
        title="🦋 Melanitis leda Landmark Batch Annotator",
        description="Upload multiple images. Each image is annotated with keypoints, and the per-image XML annotations are packaged into a ZIP archive.",
        allow_flagging="never",
    )
    iface.launch()


if __name__ == "__main__":
    main()