Theivaprakasham Hari
committed on
Commit
·
e71a089
1
Parent(s):
4edafab
first commit
Browse files- app.py +199 -0
- models/DSF_Mleda_250.pt +3 -0
- models/DSF_WSF_Mleda_450.pt +3 -0
- models/WSF_Mleda_200.pt +3 -0
- requirements.txt +14 -0
app.py
ADDED
@@ -0,0 +1,199 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import gradio as gr
|
2 |
+
import numpy as np
|
3 |
+
from PIL import Image
|
4 |
+
from ultralytics import YOLO
|
5 |
+
import cv2
|
6 |
+
from pathlib import Path
|
7 |
+
import xml.etree.ElementTree as ET
|
8 |
+
import tempfile
|
9 |
+
import zipfile
|
10 |
+
|
11 |
+
# Base path for model files
|
12 |
+
path = Path(__file__).parent
|
13 |
+
|
14 |
+
# Model configurations
|
15 |
+
MODEL_CONFIGS = {
|
16 |
+
"Dry Season Form": {
|
17 |
+
"path": path / "models/DSF_Mleda_250.pt",
|
18 |
+
"labels": [
|
19 |
+
"F1", "F2", "F3", "F4", "F5", "F6", "F7", "F8", "F9", "F10", "F11", "F12", "F13", "F14", "F15", "F16",
|
20 |
+
"H1", "H2", "H3", "H4", "H5", "H6", "H7", "H8", "H9", "H10", "H11", "H12", "H13", "H14", "H15", "H16", "H17",
|
21 |
+
"fs1", "fs2", "fs3", "fs4", "fs5",
|
22 |
+
"hs1", "hs2", "hs3", "hs4", "hs5", "hs6",
|
23 |
+
"sc1", "sc2",
|
24 |
+
"sex", "right", "left", "grey", "black", "white"
|
25 |
+
],
|
26 |
+
"imgsz": 1280
|
27 |
+
},
|
28 |
+
"Wet Season Form": {
|
29 |
+
"path": path / "models/WSF_Mleda_200.pt",
|
30 |
+
"labels": [
|
31 |
+
'F12', 'F14', 'fs1', 'fs2', 'fs3', 'fs4', 'fs5',
|
32 |
+
'H12', 'H14', 'hs1', 'hs2', 'hs3', 'hs4', 'hs5', 'hs6',
|
33 |
+
'white', 'black', 'grey', 'sex', 'blue', 'green', 'red', 'sc1', 'sc2'
|
34 |
+
],
|
35 |
+
"imgsz": 1280
|
36 |
+
},
|
37 |
+
"All Season Form": {
|
38 |
+
"path": path / "models/DSF_WSF_Mleda_450.pt",
|
39 |
+
"labels": [
|
40 |
+
"F1", "F2", "F3", "F4", "F5", "F6", "F7", "F8", "F9", "F10", "F11", "F12", "F13", "F14", "F15", "F16",
|
41 |
+
"H1", "H2", "H3", "H4", "H5", "H6", "H7", "H8", "H9", "H10", "H11", "H12", "H13", "H14", "H15", "H16", "H17",
|
42 |
+
"fs1", "fs2", "fs3", "fs4", "fs5",
|
43 |
+
"hs1", "hs2", "hs3", "hs4", "hs5", "hs6",
|
44 |
+
"sc1", "sc2",
|
45 |
+
"sex", "right", "left", "grey", "black", "white"
|
46 |
+
],
|
47 |
+
"imgsz": 1280
|
48 |
+
}
|
49 |
+
}
|
50 |
+
|
51 |
+
# Directory for XML annotations
|
52 |
+
ANNOTATIONS_DIR = Path(tempfile.gettempdir()) / "annotations"
|
53 |
+
ANNOTATIONS_DIR.mkdir(parents=True, exist_ok=True)
|
54 |
+
|
55 |
+
|
56 |
+
def hex_to_bgr(hex_color: str) -> tuple:
|
57 |
+
"""Convert #RRGGBB hex color to BGR tuple."""
|
58 |
+
hex_color = hex_color.lstrip("#")
|
59 |
+
if len(hex_color) != 6:
|
60 |
+
return (0, 255, 0) # Default to green if invalid
|
61 |
+
r = int(hex_color[0:2], 16)
|
62 |
+
g = int(hex_color[2:4], 16)
|
63 |
+
b = int(hex_color[4:6], 16)
|
64 |
+
return (b, g, r)
|
65 |
+
|
66 |
+
|
67 |
+
def load_model(path: Path):
|
68 |
+
"""Load YOLO model from the given path."""
|
69 |
+
return YOLO(str(path))
|
70 |
+
|
71 |
+
|
72 |
+
def draw_detections(image: np.ndarray, results, labels, keypoint_threshold: float,
|
73 |
+
show_labels: bool, point_size: int, point_color: str,
|
74 |
+
label_size: float) -> np.ndarray:
|
75 |
+
"""Draw bounding boxes, keypoints, and labels on the image."""
|
76 |
+
img = image.copy()
|
77 |
+
color_bgr = hex_to_bgr(point_color)
|
78 |
+
|
79 |
+
for result in results:
|
80 |
+
boxes = result.boxes.xywh.cpu().numpy()
|
81 |
+
cls_ids = result.boxes.cls.int().cpu().numpy()
|
82 |
+
confs = result.boxes.conf.cpu().numpy()
|
83 |
+
kpts_all = result.keypoints.data.cpu().numpy()
|
84 |
+
|
85 |
+
for (x_c, y_c, w, h), cls_id, conf, kpts in zip(boxes, cls_ids, confs, kpts_all):
|
86 |
+
x1 = int(x_c - w/2); y1 = int(y_c - h/2)
|
87 |
+
x2 = int(x_c + w/2); y2 = int(y_c + h/2)
|
88 |
+
|
89 |
+
cv2.rectangle(img, (x1, y1), (x2, y2), (255,255,0), 2)
|
90 |
+
text = f"{result.names[int(cls_id)]} {conf:.2f}"
|
91 |
+
(tw, th), _ = cv2.getTextSize(text, cv2.FONT_HERSHEY_SIMPLEX, 0.6, 1)
|
92 |
+
cv2.rectangle(img, (x1, y1 - th - 4), (x1 + tw, y1), (255,255,0), cv2.FILLED)
|
93 |
+
cv2.putText(img, text, (x1, y1 - 4), cv2.FONT_HERSHEY_SIMPLEX, 0.6, (0,0,0), 1)
|
94 |
+
|
95 |
+
for i, (x, y, v) in enumerate(kpts):
|
96 |
+
if v > keypoint_threshold and i < len(labels):
|
97 |
+
xi, yi = int(x), int(y)
|
98 |
+
cv2.circle(img, (xi, yi), int(point_size), color_bgr, -1)
|
99 |
+
if show_labels:
|
100 |
+
cv2.putText(img, labels[i], (xi + 3, yi + 3), cv2.FONT_HERSHEY_SIMPLEX,
|
101 |
+
label_size, (255,0,0), 2)
|
102 |
+
return img
|
103 |
+
|
104 |
+
|
105 |
+
def generate_xml(filename: str, width: int, height: int, keypoints: list[tuple[float, float, str]]):
|
106 |
+
"""Generate an XML annotation file for a single image."""
|
107 |
+
base_name = Path(filename).stem
|
108 |
+
annotations = ET.Element("annotations")
|
109 |
+
image_tag = ET.SubElement(annotations, "image", filename=filename,
|
110 |
+
width=str(width), height=str(height))
|
111 |
+
for idx, (x, y, label) in enumerate(keypoints):
|
112 |
+
ET.SubElement(image_tag, "point", id=str(idx), x=str(x), y=str(y), label=label)
|
113 |
+
tree = ET.ElementTree(annotations)
|
114 |
+
xml_filename = f"{base_name}.xml"
|
115 |
+
xml_path = ANNOTATIONS_DIR / xml_filename
|
116 |
+
tree.write(str(xml_path), encoding="utf-8", xml_declaration=True)
|
117 |
+
return xml_path
|
118 |
+
|
119 |
+
|
120 |
+
def process_images(image_list, conf_threshold: float, keypoint_threshold: float,
|
121 |
+
model_choice: str, show_labels: bool, point_size: int,
|
122 |
+
point_color: str, label_size: float):
|
123 |
+
"""Process multiple images: annotate and generate XMLs, then package into ZIP."""
|
124 |
+
model_cfg = MODEL_CONFIGS[model_choice]
|
125 |
+
model = load_model(model_cfg["path"])
|
126 |
+
labels = model_cfg["labels"]
|
127 |
+
imgsz = model_cfg["imgsz"]
|
128 |
+
|
129 |
+
output_images = []
|
130 |
+
xml_paths = []
|
131 |
+
|
132 |
+
for file_obj in image_list:
|
133 |
+
# Determine path and original filename
|
134 |
+
if isinstance(file_obj, dict):
|
135 |
+
tmp_path = Path(file_obj['name'])
|
136 |
+
orig_name = Path(file_obj['orig_name']).name
|
137 |
+
else:
|
138 |
+
tmp_path = Path(file_obj.name)
|
139 |
+
orig_name = tmp_path.name
|
140 |
+
|
141 |
+
img = Image.open(tmp_path)
|
142 |
+
img_rgb = np.array(img.convert("RGB"))
|
143 |
+
img_bgr = cv2.cvtColor(img_rgb, cv2.COLOR_RGB2BGR)
|
144 |
+
height, width = img_bgr.shape[:2]
|
145 |
+
|
146 |
+
results = model(img_bgr, conf=conf_threshold, imgsz=imgsz)
|
147 |
+
annotated = draw_detections(img_bgr, results, labels,
|
148 |
+
keypoint_threshold, show_labels,
|
149 |
+
point_size, point_color, label_size)
|
150 |
+
output_images.append(Image.fromarray(cv2.cvtColor(annotated, cv2.COLOR_BGR2RGB)))
|
151 |
+
|
152 |
+
# Collect keypoints for XML
|
153 |
+
keypoints = []
|
154 |
+
for res in results:
|
155 |
+
for kpts in res.keypoints.data.cpu().numpy():
|
156 |
+
for i, (x, y, v) in enumerate(kpts):
|
157 |
+
if v > keypoint_threshold:
|
158 |
+
keypoints.append((float(x), float(y), labels[i] if i < len(labels) else f"kp{i}"))
|
159 |
+
|
160 |
+
xml_path = generate_xml(orig_name, width, height, keypoints)
|
161 |
+
xml_paths.append(xml_path)
|
162 |
+
|
163 |
+
# Create ZIP of all XMLs
|
164 |
+
zip_path = Path(tempfile.gettempdir()) / "xml_annotations.zip"
|
165 |
+
with zipfile.ZipFile(zip_path, 'w') as zipf:
|
166 |
+
for p in xml_paths:
|
167 |
+
arcname = Path('annotations') / p.name
|
168 |
+
zipf.write(str(p), arcname.as_posix())
|
169 |
+
|
170 |
+
xml_list_str = "\n".join(str(p) for p in xml_paths)
|
171 |
+
return output_images, xml_list_str, str(zip_path)
|
172 |
+
|
173 |
+
# Gradio Interface
|
174 |
+
def main():
|
175 |
+
iface = gr.Interface(
|
176 |
+
fn=process_images,
|
177 |
+
inputs=[
|
178 |
+
gr.File(file_types=["image"], file_count="multiple", label="Upload Images"),
|
179 |
+
gr.Slider(minimum=0.0, maximum=1.0, value=0.25, step=0.01, label="Confidence Threshold"),
|
180 |
+
gr.Slider(minimum=0.0, maximum=1.0, value=0.5, step=0.01, label="Keypoint Visibility Threshold"),
|
181 |
+
gr.Radio(choices=list(MODEL_CONFIGS.keys()), label="Select Model", value="Dry Season Form"),
|
182 |
+
gr.Checkbox(label="Show Keypoint Labels", value=True),
|
183 |
+
gr.Slider(minimum=1, maximum=20, value=8, step=1, label="Keypoint Size"),
|
184 |
+
gr.ColorPicker(label="Keypoint Color", value="#00FF00"),
|
185 |
+
gr.Slider(minimum=0.3, maximum=3.0, value=1.0, step=0.1, label="Keypoint Label Font Size")
|
186 |
+
],
|
187 |
+
outputs=[
|
188 |
+
gr.Gallery(label="Detection Results"),
|
189 |
+
gr.Textbox(label="Generated XML Paths"),
|
190 |
+
gr.File(label="Download All XMLs as ZIP")
|
191 |
+
],
|
192 |
+
title="🦋 Melanitis leda Landmark Batch Annotator",
|
193 |
+
description="Upload multiple images. It annotates each with keypoints and packages XMLs in a ZIP archive.",
|
194 |
+
allow_flagging="never"
|
195 |
+
)
|
196 |
+
iface.launch()
|
197 |
+
|
198 |
+
if __name__ == "__main__":
|
199 |
+
main()
|
models/DSF_Mleda_250.pt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:5b51ea730cb1c03399bc7d1de9fe51ca6f531af0d29c3064de9638a4a8ec701c
|
3 |
+
size 56864045
|
models/DSF_WSF_Mleda_450.pt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:59cd20c486514ff18080413657d784f76c95937db5d5f6fa8cfae1f95b24951a
|
3 |
+
size 56674033
|
models/WSF_Mleda_200.pt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:33b8298a9121ae3ccf5b5d5aff000d6dddda6f6e1c3e35a09f1215b8f53dd47c
|
3 |
+
size 53621489
|
requirements.txt
ADDED
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
ultralytics==8.3.88
|
2 |
+
torch==2.6.0
|
3 |
+
torchvision==0.21.0
|
4 |
+
pycocotools==2.0.7
|
5 |
+
PyYAML==6.0.1
|
6 |
+
scipy==1.15.2
|
7 |
+
gradio==5.29.0
|
8 |
+
numpy==1.26.4
|
9 |
+
psutil==5.9.8
|
10 |
+
py-cpuinfo==9.0.0
|
11 |
+
safetensors==0.4.3
|
12 |
+
pillow==11.1.0
|
13 |
+
opencv-python==4.11.0.86
|
14 |
+
opencv-python-headless==4.11.0.86
|