Theivaprakasham Hari committed on
Commit
e71a089
·
1 Parent(s): 4edafab

first commit

Browse files
app.py ADDED
@@ -0,0 +1,199 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import gradio as gr
import numpy as np
from PIL import Image
from ultralytics import YOLO
import cv2
from pathlib import Path
import xml.etree.ElementTree as ET
import tempfile
import zipfile

# Base path for model files (the directory containing this script).
path = Path(__file__).parent

# Model configurations.
# For each selectable season form:
#   "path"   - YOLO pose checkpoint (.pt) relative to this script,
#   "labels" - ordered keypoint names: index i in this list names keypoint i
#              of the model's output (wing landmarks F*/H*, spots fs*/hs*,
#              scale points sc*, plus reference markers),
#   "imgsz"  - inference image size passed to the model.
MODEL_CONFIGS = {
    "Dry Season Form": {
        "path": path / "models/DSF_Mleda_250.pt",
        "labels": [
            "F1", "F2", "F3", "F4", "F5", "F6", "F7", "F8", "F9", "F10", "F11", "F12", "F13", "F14", "F15", "F16",
            "H1", "H2", "H3", "H4", "H5", "H6", "H7", "H8", "H9", "H10", "H11", "H12", "H13", "H14", "H15", "H16", "H17",
            "fs1", "fs2", "fs3", "fs4", "fs5",
            "hs1", "hs2", "hs3", "hs4", "hs5", "hs6",
            "sc1", "sc2",
            "sex", "right", "left", "grey", "black", "white"
        ],
        "imgsz": 1280
    },
    "Wet Season Form": {
        "path": path / "models/WSF_Mleda_200.pt",
        "labels": [
            'F12', 'F14', 'fs1', 'fs2', 'fs3', 'fs4', 'fs5',
            'H12', 'H14', 'hs1', 'hs2', 'hs3', 'hs4', 'hs5', 'hs6',
            'white', 'black', 'grey', 'sex', 'blue', 'green', 'red', 'sc1', 'sc2'
        ],
        "imgsz": 1280
    },
    "All Season Form": {
        "path": path / "models/DSF_WSF_Mleda_450.pt",
        "labels": [
            "F1", "F2", "F3", "F4", "F5", "F6", "F7", "F8", "F9", "F10", "F11", "F12", "F13", "F14", "F15", "F16",
            "H1", "H2", "H3", "H4", "H5", "H6", "H7", "H8", "H9", "H10", "H11", "H12", "H13", "H14", "H15", "H16", "H17",
            "fs1", "fs2", "fs3", "fs4", "fs5",
            "hs1", "hs2", "hs3", "hs4", "hs5", "hs6",
            "sc1", "sc2",
            "sex", "right", "left", "grey", "black", "white"
        ],
        "imgsz": 1280
    }
}

# Directory for XML annotations, created under the system temp dir.
# NOTE(review): files accumulate here across app runs; presumably acceptable
# for a temp dir, but confirm if stale annotations could leak between users.
ANNOTATIONS_DIR = Path(tempfile.gettempdir()) / "annotations"
ANNOTATIONS_DIR.mkdir(parents=True, exist_ok=True)
54
+
55
+
56
def hex_to_bgr(hex_color: str) -> tuple:
    """Convert a '#RRGGBB' hex color string into an OpenCV BGR tuple.

    Falls back to green (0, 255, 0) when the string is not exactly six
    hex digits after stripping a leading '#'.
    """
    digits = hex_color.lstrip("#")
    if len(digits) != 6:
        # Invalid or shorthand input -> default to green.
        return (0, 255, 0)
    r, g, b = (int(digits[i:i + 2], 16) for i in (0, 2, 4))
    return (b, g, r)
65
+
66
+
67
def load_model(path: Path):
    """Instantiate and return a YOLO model from *path* (a .pt checkpoint)."""
    checkpoint = str(path)
    return YOLO(checkpoint)
70
+
71
+
72
def draw_detections(image: np.ndarray, results, labels, keypoint_threshold: float,
                    show_labels: bool, point_size: int, point_color: str,
                    label_size: float) -> np.ndarray:
    """Render YOLO pose results onto a copy of *image* (BGR, uint8).

    For each detection: draws a cyan bounding box with a class/confidence
    caption, then a filled circle (and, optionally, a text label) at every
    keypoint whose visibility score exceeds *keypoint_threshold*.
    Returns the annotated copy; the input image is left untouched.
    """
    canvas = image.copy()
    bgr = hex_to_bgr(point_color)

    for result in results:
        xywh = result.boxes.xywh.cpu().numpy()
        classes = result.boxes.cls.int().cpu().numpy()
        scores = result.boxes.conf.cpu().numpy()
        keypoint_sets = result.keypoints.data.cpu().numpy()

        for box, cls_id, score, kpts in zip(xywh, classes, scores, keypoint_sets):
            x_c, y_c, w, h = box
            # Convert center/size box to integer corner coordinates.
            x1, y1 = int(x_c - w / 2), int(y_c - h / 2)
            x2, y2 = int(x_c + w / 2), int(y_c + h / 2)

            cv2.rectangle(canvas, (x1, y1), (x2, y2), (255, 255, 0), 2)
            caption = f"{result.names[int(cls_id)]} {score:.2f}"
            (tw, th), _ = cv2.getTextSize(caption, cv2.FONT_HERSHEY_SIMPLEX, 0.6, 1)
            # Filled background strip so the caption stays readable.
            cv2.rectangle(canvas, (x1, y1 - th - 4), (x1 + tw, y1), (255, 255, 0), cv2.FILLED)
            cv2.putText(canvas, caption, (x1, y1 - 4), cv2.FONT_HERSHEY_SIMPLEX, 0.6, (0, 0, 0), 1)

            for idx, (kx, ky, vis) in enumerate(kpts):
                # Skip low-visibility keypoints and indices beyond the label list.
                if vis > keypoint_threshold and idx < len(labels):
                    px, py = int(kx), int(ky)
                    cv2.circle(canvas, (px, py), int(point_size), bgr, -1)
                    if show_labels:
                        cv2.putText(canvas, labels[idx], (px + 3, py + 3),
                                    cv2.FONT_HERSHEY_SIMPLEX, label_size, (255, 0, 0), 2)
    return canvas
103
+
104
+
105
def generate_xml(filename: str, width: int, height: int, keypoints: list[tuple[float, float, str]]):
    """Write one XML annotation file for *filename* and return its path.

    The document root is <annotations> containing a single <image> element
    (filename plus pixel dimensions); each keypoint becomes a <point> child
    with a sequential id, its coordinates, and its label. The file is saved
    as '<stem>.xml' inside ANNOTATIONS_DIR, overwriting any previous one.
    """
    root = ET.Element("annotations")
    image_el = ET.SubElement(root, "image", filename=filename,
                             width=str(width), height=str(height))
    for idx, (px, py, label) in enumerate(keypoints):
        ET.SubElement(image_el, "point", id=str(idx), x=str(px), y=str(py), label=label)

    out_path = ANNOTATIONS_DIR / f"{Path(filename).stem}.xml"
    ET.ElementTree(root).write(str(out_path), encoding="utf-8", xml_declaration=True)
    return out_path
118
+
119
+
120
def process_images(image_list, conf_threshold: float, keypoint_threshold: float,
                   model_choice: str, show_labels: bool, point_size: int,
                   point_color: str, label_size: float):
    """Run pose inference on a batch of uploaded images.

    For each image: runs the selected YOLO model, draws the detections,
    writes an XML annotation of the visible keypoints, and finally packages
    every XML into a single ZIP archive.

    Returns (annotated PIL images, newline-joined XML paths, ZIP path).
    """
    model_cfg = MODEL_CONFIGS[model_choice]
    model = load_model(model_cfg["path"])
    labels = model_cfg["labels"]
    imgsz = model_cfg["imgsz"]

    # Gradio passes None when the user submits without uploading any file;
    # treat that as an empty batch instead of crashing on iteration.
    image_list = image_list or []

    output_images = []
    xml_paths = []

    for file_obj in image_list:
        # Older Gradio versions deliver dicts with 'name'/'orig_name';
        # newer ones deliver file wrappers exposing a .name attribute.
        if isinstance(file_obj, dict):
            tmp_path = Path(file_obj['name'])
            orig_name = Path(file_obj['orig_name']).name
        else:
            tmp_path = Path(file_obj.name)
            orig_name = tmp_path.name

        # Context manager closes the underlying file handle promptly
        # (the original leaked one handle per image).
        with Image.open(tmp_path) as img:
            img_rgb = np.array(img.convert("RGB"))
        img_bgr = cv2.cvtColor(img_rgb, cv2.COLOR_RGB2BGR)
        height, width = img_bgr.shape[:2]

        results = model(img_bgr, conf=conf_threshold, imgsz=imgsz)
        annotated = draw_detections(img_bgr, results, labels,
                                    keypoint_threshold, show_labels,
                                    point_size, point_color, label_size)
        output_images.append(Image.fromarray(cv2.cvtColor(annotated, cv2.COLOR_BGR2RGB)))

        # Collect keypoints above the visibility threshold for the XML;
        # indices beyond the label list fall back to generic "kp{i}" names.
        keypoints = []
        for res in results:
            for kpts in res.keypoints.data.cpu().numpy():
                for i, (x, y, v) in enumerate(kpts):
                    if v > keypoint_threshold:
                        keypoints.append((float(x), float(y),
                                          labels[i] if i < len(labels) else f"kp{i}"))

        xml_paths.append(generate_xml(orig_name, width, height, keypoints))

    # Create a ZIP of all XMLs under an 'annotations/' folder.
    zip_path = Path(tempfile.gettempdir()) / "xml_annotations.zip"
    with zipfile.ZipFile(zip_path, 'w') as zipf:
        for p in xml_paths:
            arcname = Path('annotations') / p.name
            zipf.write(str(p), arcname.as_posix())

    xml_list_str = "\n".join(str(p) for p in xml_paths)
    return output_images, xml_list_str, str(zip_path)
172
+
173
+ # Gradio Interface
174
def main():
    """Build and launch the Gradio batch-annotation interface."""
    iface = gr.Interface(
        fn=process_images,
        inputs=[
            gr.File(file_types=["image"], file_count="multiple", label="Upload Images"),
            gr.Slider(minimum=0.0, maximum=1.0, value=0.25, step=0.01, label="Confidence Threshold"),
            gr.Slider(minimum=0.0, maximum=1.0, value=0.5, step=0.01, label="Keypoint Visibility Threshold"),
            gr.Radio(choices=list(MODEL_CONFIGS.keys()), label="Select Model", value="Dry Season Form"),
            gr.Checkbox(label="Show Keypoint Labels", value=True),
            gr.Slider(minimum=1, maximum=20, value=8, step=1, label="Keypoint Size"),
            gr.ColorPicker(label="Keypoint Color", value="#00FF00"),
            gr.Slider(minimum=0.3, maximum=3.0, value=1.0, step=0.1, label="Keypoint Label Font Size")
        ],
        outputs=[
            gr.Gallery(label="Detection Results"),
            gr.Textbox(label="Generated XML Paths"),
            gr.File(label="Download All XMLs as ZIP")
        ],
        title="🦋 Melanitis leda Landmark Batch Annotator",
        description="Upload multiple images. It annotates each with keypoints and packages XMLs in a ZIP archive.",
        # 'allow_flagging' is deprecated in Gradio 4.x and replaced by
        # 'flagging_mode' in Gradio 5.x (requirements pin gradio==5.29.0).
        flagging_mode="never"
    )
    iface.launch()

if __name__ == "__main__":
    main()
models/DSF_Mleda_250.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:5b51ea730cb1c03399bc7d1de9fe51ca6f531af0d29c3064de9638a4a8ec701c
3
+ size 56864045
models/DSF_WSF_Mleda_450.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:59cd20c486514ff18080413657d784f76c95937db5d5f6fa8cfae1f95b24951a
3
+ size 56674033
models/WSF_Mleda_200.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:33b8298a9121ae3ccf5b5d5aff000d6dddda6f6e1c3e35a09f1215b8f53dd47c
3
+ size 53621489
requirements.txt ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ultralytics==8.3.88
2
+ torch==2.6.0
3
+ torchvision==0.21.0
4
+ pycocotools==2.0.7
5
+ PyYAML==6.0.1
6
+ scipy==1.15.2
7
+ gradio==5.29.0
8
+ numpy==1.26.4
9
+ psutil==5.9.8
10
+ py-cpuinfo==9.0.0
11
+ safetensors==0.4.3
12
+ pillow==11.1.0
13
+ opencv-python==4.11.0.86
14
+ opencv-python-headless==4.11.0.86