import gradio as gr
import numpy as np
from PIL import Image
from ultralytics import YOLO
import cv2
from pathlib import Path
import xml.etree.ElementTree as ET
import tempfile
import zipfile
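# Third-party dependencies: gradio, ultralytics, opencv-python, pillow, numpy
# (pathlib, xml.etree, tempfile, and zipfile are part of the standard library).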
# Base path for model files
path = Path(__file__).parent
# Model configurations
MODEL_CONFIGS = {
    "Dry Season Form": {
        "path": path / "models/DSF_Mleda_250.pt",
        "labels": [
            "F1", "F2", "F3", "F4", "F5", "F6", "F7", "F8", "F9", "F10", "F11", "F12", "F13", "F14", "F15", "F16",
            "H1", "H2", "H3", "H4", "H5", "H6", "H7", "H8", "H9", "H10", "H11", "H12", "H13", "H14", "H15", "H16", "H17",
            "fs1", "fs2", "fs3", "fs4", "fs5",
            "hs1", "hs2", "hs3", "hs4", "hs5", "hs6",
            "sc1", "sc2",
            "sex", "right", "left", "grey", "black", "white"
        ],
        "imgsz": 1280
    },
    "Wet Season Form": {
        "path": path / "models/WSF_Mleda_200.pt",
        "labels": [
            "F12", "F14", "fs1", "fs2", "fs3", "fs4", "fs5",
            "H12", "H14", "hs1", "hs2", "hs3", "hs4", "hs5", "hs6",
            "white", "black", "grey", "sex", "blue", "green", "red", "sc1", "sc2"
        ],
        "imgsz": 1280
    },
    "All Season Form": {
        "path": path / "models/DSF_WSF_Mleda_450.pt",
        "labels": [
            "F1", "F2", "F3", "F4", "F5", "F6", "F7", "F8", "F9", "F10", "F11", "F12", "F13", "F14", "F15", "F16",
            "H1", "H2", "H3", "H4", "H5", "H6", "H7", "H8", "H9", "H10", "H11", "H12", "H13", "H14", "H15", "H16", "H17",
            "fs1", "fs2", "fs3", "fs4", "fs5",
            "hs1", "hs2", "hs3", "hs4", "hs5", "hs6",
            "sc1", "sc2",
            "sex", "right", "left", "grey", "black", "white"
        ],
        "imgsz": 1280
    }
}
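# Each "labels" list gives the keypoint names in the order the corresponding pose
# model outputs them: labels[i] must line up with keypoint index i in the model's
# results, since draw_detections() and the XML writer index into this list by
# keypoint position.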
# Directory for XML annotations
ANNOTATIONS_DIR = Path(tempfile.gettempdir()) / "annotations"
ANNOTATIONS_DIR.mkdir(parents=True, exist_ok=True)
def hex_to_bgr(hex_color: str) -> tuple:
    """Convert a #RRGGBB hex color string to an OpenCV BGR tuple."""
    hex_color = hex_color.lstrip("#")
    if len(hex_color) != 6:
        return (0, 255, 0)  # Default to green if the color string is invalid
    r = int(hex_color[0:2], 16)
    g = int(hex_color[2:4], 16)
    b = int(hex_color[4:6], 16)
    return (b, g, r)
def load_model(model_path: Path):
    """Load a YOLO model from the given weights path."""
    return YOLO(str(model_path))
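# Optional: cache loaded models so repeated requests with the same model choice
# don't reload the weights from disk each time. This is a small illustrative
# helper; the app below still calls load_model() directly.
_MODEL_CACHE: dict[str, YOLO] = {}
def load_model_cached(model_path: Path) -> YOLO:
    """Like load_model(), but reuses an already-loaded model for the same path."""
    key = str(model_path)
    if key not in _MODEL_CACHE:
        _MODEL_CACHE[key] = load_model(model_path)
    return _MODEL_CACHE[key]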
def draw_detections(image: np.ndarray, results, labels, keypoint_threshold: float,
                    show_labels: bool, point_size: int, point_color: str,
                    label_size: float) -> np.ndarray:
    """Draw bounding boxes, keypoints, and labels on a copy of the image (BGR)."""
    img = image.copy()
    color_bgr = hex_to_bgr(point_color)
    for result in results:
        boxes = result.boxes.xywh.cpu().numpy()
        cls_ids = result.boxes.cls.int().cpu().numpy()
        confs = result.boxes.conf.cpu().numpy()
        # keypoints.data has shape (num_detections, num_keypoints, 3): x, y, confidence
        kpts_all = result.keypoints.data.cpu().numpy()
        for (x_c, y_c, w, h), cls_id, conf, kpts in zip(boxes, cls_ids, confs, kpts_all):
            # Convert the center-based xywh box to corner coordinates
            x1 = int(x_c - w / 2); y1 = int(y_c - h / 2)
            x2 = int(x_c + w / 2); y2 = int(y_c + h / 2)
            cv2.rectangle(img, (x1, y1), (x2, y2), (255, 255, 0), 2)
            text = f"{result.names[int(cls_id)]} {conf:.2f}"
            (tw, th), _ = cv2.getTextSize(text, cv2.FONT_HERSHEY_SIMPLEX, 0.6, 1)
            cv2.rectangle(img, (x1, y1 - th - 4), (x1 + tw, y1), (255, 255, 0), cv2.FILLED)
            cv2.putText(img, text, (x1, y1 - 4), cv2.FONT_HERSHEY_SIMPLEX, 0.6, (0, 0, 0), 1)
            for i, (x, y, v) in enumerate(kpts):
                # v is the per-keypoint confidence; skip low-confidence or unnamed points
                if v > keypoint_threshold and i < len(labels):
                    xi, yi = int(x), int(y)
                    cv2.circle(img, (xi, yi), int(point_size), color_bgr, -1)
                    if show_labels:
                        cv2.putText(img, labels[i], (xi + 3, yi + 3), cv2.FONT_HERSHEY_SIMPLEX,
                                    label_size, (255, 0, 0), 2)
    return img
def generate_xml(filename: str, width: int, height: int, keypoints: list[tuple[float, float, str]]):
    """Generate an XML annotation file for a single image."""
    annotations = ET.Element("annotations")
    image_tag = ET.SubElement(annotations, "image", filename=filename,
                              width=str(width), height=str(height))
    for idx, (x, y, label) in enumerate(keypoints):
        ET.SubElement(image_tag, "point", id=str(idx), x=str(x), y=str(y), label=label)
    tree = ET.ElementTree(annotations)
    xml_filename = f"{filename}.xml"
    xml_path = ANNOTATIONS_DIR / xml_filename
    tree.write(str(xml_path), encoding="utf-8", xml_declaration=True)
    return xml_path
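# The XML written by generate_xml() has the structure below (shown pretty-printed
# here; ElementTree writes it without indentation, and the coordinate values and
# filename are illustrative):
#
#   <?xml version='1.0' encoding='utf-8'?>
#   <annotations>
#     <image filename="leda_001.jpg" width="4000" height="3000">
#       <point id="0" x="1523.4" y="880.1" label="F1" />
#       <point id="1" x="1601.0" y="934.7" label="F2" />
#     </image>
#   </annotations>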
def process_images(image_list, conf_threshold: float, keypoint_threshold: float,
                   model_choice: str, show_labels: bool, point_size: int,
                   point_color: str, label_size: float):
    """Process multiple images: annotate each one, generate per-image XMLs, then package the XMLs into a ZIP."""
    model_cfg = MODEL_CONFIGS[model_choice]
    model = load_model(model_cfg["path"])
    labels = model_cfg["labels"]
    imgsz = model_cfg["imgsz"]
    output_images = []
    xml_paths = []
    for file_obj in image_list:
        # Determine the temp path and original filename. Depending on the Gradio
        # version, gr.File yields plain file paths, dicts, or tempfile-like objects.
        if isinstance(file_obj, (str, Path)):
            tmp_path = Path(file_obj)
            orig_name = tmp_path.name
        elif isinstance(file_obj, dict):
            tmp_path = Path(file_obj["name"])
            orig_name = Path(file_obj["orig_name"]).name
        else:
            tmp_path = Path(file_obj.name)
            orig_name = tmp_path.name
        img = Image.open(tmp_path)
        img_rgb = np.array(img.convert("RGB"))
        img_bgr = cv2.cvtColor(img_rgb, cv2.COLOR_RGB2BGR)
        height, width = img_bgr.shape[:2]
        results = model(img_bgr, conf=conf_threshold, imgsz=imgsz)
        annotated = draw_detections(img_bgr, results, labels,
                                    keypoint_threshold, show_labels,
                                    point_size, point_color, label_size)
        output_images.append(Image.fromarray(cv2.cvtColor(annotated, cv2.COLOR_BGR2RGB)))
        # Collect keypoints above the visibility threshold for the XML annotation
        keypoints = []
        for res in results:
            for kpts in res.keypoints.data.cpu().numpy():
                for i, (x, y, v) in enumerate(kpts):
                    if v > keypoint_threshold:
                        keypoints.append((float(x), float(y), labels[i] if i < len(labels) else f"kp{i}"))
        xml_path = generate_xml(orig_name, width, height, keypoints)
        xml_paths.append(xml_path)
    # Create a ZIP containing all generated XML files
    zip_path = Path(tempfile.gettempdir()) / "xml_annotations.zip"
    with zipfile.ZipFile(zip_path, "w") as zipf:
        for p in xml_paths:
            arcname = Path("annotations") / p.name
            zipf.write(str(p), arcname.as_posix())
    xml_list_str = "\n".join(str(p) for p in xml_paths)
    return output_images, xml_list_str, str(zip_path)
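# Example of calling process_images() directly for a quick smoke test outside
# Gradio (the image filename below is a placeholder):
#
#     imgs, xml_paths_str, zip_file = process_images(
#         ["sample_leda.jpg"], 0.25, 0.5, "Dry Season Form",
#         True, 8, "#00FF00", 1.0
#     )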
# Gradio Interface
def main():
    iface = gr.Interface(
        fn=process_images,
        inputs=[
            gr.File(file_types=["image"], file_count="multiple", label="Upload Images"),
            gr.Slider(minimum=0.0, maximum=1.0, value=0.25, step=0.01, label="Confidence Threshold"),
            gr.Slider(minimum=0.0, maximum=1.0, value=0.5, step=0.01, label="Keypoint Visibility Threshold"),
            gr.Radio(choices=list(MODEL_CONFIGS.keys()), label="Select Model", value="Dry Season Form"),
            gr.Checkbox(label="Show Keypoint Labels", value=True),
            gr.Slider(minimum=1, maximum=20, value=8, step=1, label="Keypoint Size"),
            gr.ColorPicker(label="Keypoint Color", value="#00FF00"),
            gr.Slider(minimum=0.3, maximum=3.0, value=1.0, step=0.1, label="Keypoint Label Font Size")
        ],
        outputs=[
            gr.Gallery(label="Detection Results"),
            gr.Textbox(label="Generated XML Paths"),
            gr.File(label="Download All XMLs as ZIP")
        ],
        title="🦋 Melanitis leda Landmark Batch Annotator",
        description="Upload one or more images; each is annotated with the detected landmarks, and the per-image XML annotations are packaged into a ZIP archive.",
        allow_flagging="never"
    )
    iface.launch()
if __name__ == "__main__":
    main()
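# iface.launch() serves the app locally (by default at http://127.0.0.1:7860).
# Pass share=True to launch() to get a temporary public link, e.g.:
#
#     iface.launch(share=True)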