Theivaprakasham Hari committed
Commit c9da29e · 1 Parent(s): b49bef5

First push

app.py ADDED
@@ -0,0 +1,125 @@
+ import gradio as gr
+ import numpy as np
+ from PIL import Image
+ from ultralytics import YOLO
+ import cv2
+
+ # Model configurations
+ MODEL_CONFIGS = {
+     "Dry Season Form": {
+         "path": "models/DSF_Mleda_250.pt",
+         "labels": [
+             "F1", "F2", "F3", "F4", "F5", "F6", "F7", "F8", "F9", "F10", "F11", "F12", "F13", "F14", "F15", "F16",
+             "H1", "H2", "H3", "H4", "H5", "H6", "H7", "H8", "H9", "H10", "H11", "H12", "H13", "H14", "H15", "H16", "H17",
+             "fs1", "fs2", "fs3", "fs4", "fs5",
+             "hs1", "hs2", "hs3", "hs4", "hs5", "hs6",
+             "sc1", "sc2",
+             "sex", "right", "left", "grey", "black", "white"
+         ],
+         "imgsz": 1280
+     },
+     "Wet Season Form": {
+         "path": "models/WSF_Mleda_200.pt",
+         "labels": [
+             "F12", "F14", "fs1", "fs2", "fs3", "fs4", "fs5",
+             "H12", "H14", "hs1", "hs2", "hs3", "hs4", "hs5", "hs6",
+             "white", "black", "grey", "sex", "blue", "green", "red", "sc1", "sc2"
+         ],
+         "imgsz": 1280
+     },
+     "All Season Form": {
+         "path": "models/DSF_WSF_Mleda_450.pt",
+         "labels": [
+             "F1", "F2", "F3", "F4", "F5", "F6", "F7", "F8", "F9", "F10", "F11", "F12", "F13", "F14", "F15", "F16",
+             "H1", "H2", "H3", "H4", "H5", "H6", "H7", "H8", "H9", "H10", "H11", "H12", "H13", "H14", "H15", "H16", "H17",
+             "fs1", "fs2", "fs3", "fs4", "fs5",
+             "hs1", "hs2", "hs3", "hs4", "hs5", "hs6",
+             "sc1", "sc2",
+             "sex", "right", "left", "grey", "black", "white"
+         ],
+         "imgsz": 1280
+     }
+ }
+
+ def hex_to_bgr(hex_color: str) -> tuple:
+     """Convert #RRGGBB hex color to BGR tuple."""
+     hex_color = hex_color.lstrip("#")
+     if len(hex_color) != 6:
+         return (0, 255, 0)  # Default to green if invalid
+     r = int(hex_color[0:2], 16)
+     g = int(hex_color[2:4], 16)
+     b = int(hex_color[4:6], 16)
+     return (b, g, r)
+
+ def load_model(path):
+     """Load YOLO model from the given path."""
+     return YOLO(path)
+
+ def draw_detections(image: np.ndarray, results, labels, keypoint_threshold: float, show_labels: bool, point_size: int, point_color: str, label_size: float) -> np.ndarray:
+     img = image.copy()
+     color_bgr = hex_to_bgr(point_color)
+
+     for result in results:
+         boxes = result.boxes.xywh.cpu().numpy()
+         cls_ids = result.boxes.cls.int().cpu().numpy()
+         confs = result.boxes.conf.cpu().numpy()
+         kpts_all = result.keypoints.data.cpu().numpy()
+
+         for (x_c, y_c, w, h), cls_id, conf, kpts in zip(boxes, cls_ids, confs, kpts_all):
+             x1 = int(x_c - w/2); y1 = int(y_c - h/2)
+             x2 = int(x_c + w/2); y2 = int(y_c + h/2)
+
+             cv2.rectangle(img, (x1, y1), (x2, y2), color=(255,255,0), thickness=2)
+
+             class_name = result.names[int(cls_id)]
+             text = f"{class_name} {conf:.2f}"
+             (tw, th), _ = cv2.getTextSize(text, cv2.FONT_HERSHEY_SIMPLEX, 0.6, 1)
+             cv2.rectangle(img, (x1, y1 - th - 4), (x1 + tw, y1), (255,255,0), cv2.FILLED)
+             cv2.putText(img, text, (x1, y1 - 4), cv2.FONT_HERSHEY_SIMPLEX, 0.6, (0,0,0), 1, cv2.LINE_AA)
+
+             for i, (x, y, v) in enumerate(kpts):
+                 if v > keypoint_threshold and i < len(labels):
+                     xi, yi = int(x), int(y)
+                     cv2.circle(img, (xi, yi), int(point_size), color_bgr, -1, cv2.LINE_AA)
+                     if show_labels:
+                         cv2.putText(img, labels[i], (xi + 3, yi + 3), cv2.FONT_HERSHEY_SIMPLEX, label_size, (255,0,0), 2, cv2.LINE_AA)
+     return img
+
+
+ def predict_and_annotate(input_image: Image.Image, conf_threshold: float, keypoint_threshold: float, model_choice: str, show_labels: bool, point_size: int, point_color: str, label_size: float):
+     config = MODEL_CONFIGS[model_choice]
+     model = load_model(config["path"])
+     labels = config["labels"]
+     imgsz = config["imgsz"]
+
+     img_rgb = np.array(input_image.convert("RGB"))
+     img_bgr = cv2.cvtColor(img_rgb, cv2.COLOR_RGB2BGR)
+
+     results = model(img_bgr, conf=conf_threshold, imgsz=imgsz)
+     annotated_bgr = draw_detections(img_bgr, results, labels, keypoint_threshold, show_labels, point_size, point_color, label_size)
+     annotated_rgb = cv2.cvtColor(annotated_bgr, cv2.COLOR_BGR2RGB)
+     return Image.fromarray(annotated_rgb)
+
+
+ # Gradio Interface
+ app = gr.Interface(
+     fn=predict_and_annotate,
+     inputs=[
+         gr.Image(type="pil", label="Upload Image"),
+         gr.Slider(minimum=0.0, maximum=1.0, value=0.25, step=0.01, label="Confidence Threshold"),
+         gr.Slider(minimum=0.0, maximum=1.0, value=0.5, step=0.01, label="Keypoint Visibility Threshold"),
+         gr.Radio(choices=list(MODEL_CONFIGS.keys()), label="Select Model", value="Dry Season Form"),
+         gr.Checkbox(label="Show Keypoint Labels", value=True),
+         gr.Slider(minimum=1, maximum=20, value=8, step=0.1, label="Keypoint Size"),
+         gr.ColorPicker(label="Keypoint Color", value="#00FF00"),
+         gr.Slider(minimum=0.3, maximum=3.0, value=1.0, step=0.1, label="Keypoint Label Font Size")
+     ],
+     outputs=gr.Image(type="pil", label="Detection Result", format="png"),
+     title="🦋 Melanitis leda Landmark Identification",
+     description="Upload an image and select the model. Customize detection and keypoint display settings.",
+     flagging_mode="never"
+ )
+
+
+ if __name__ == "__main__":
+     app.launch()
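For a quick check outside the Gradio UI, the prediction function added above can be called directly on a PIL image. This is an illustrative sketch only: it assumes the repository has been cloned with its Git LFS weights resolved, and `sample.jpg` is a hypothetical input file.

    from PIL import Image
    from app import predict_and_annotate  # the app.py file added in this commit

    # Hypothetical test image; substitute a real Melanitis leda photograph.
    input_img = Image.open("sample.jpg")

    # Arguments mirror the Gradio controls: confidence 0.25, keypoint visibility 0.5,
    # "Dry Season Form" weights, labels shown, 8 px points, green markers, font scale 1.0.
    annotated = predict_and_annotate(input_img, 0.25, 0.5, "Dry Season Form", True, 8, "#00FF00", 1.0)
    annotated.save("annotated.png")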
models/DSF_Mleda_250.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:5b51ea730cb1c03399bc7d1de9fe51ca6f531af0d29c3064de9638a4a8ec701c
+ size 56864045
models/DSF_WSF_Mleda_450.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:59cd20c486514ff18080413657d784f76c95937db5d5f6fa8cfae1f95b24951a
+ size 56674033
models/WSF_Mleda_200.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:33b8298a9121ae3ccf5b5d5aff000d6dddda6f6e1c3e35a09f1215b8f53dd47c
+ size 53621489
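The three .pt files above are Git LFS pointers; the actual weights (roughly 54-57 MB each) are fetched when the repository is cloned with LFS support. Once resolved, a weight file can also be loaded directly with the Ultralytics API, independent of app.py — a minimal sketch assuming the working directory is the repository root and `butterfly.jpg` is a placeholder image:

    from ultralytics import YOLO

    # Load the dry-season-form pose model (resolved LFS weight file).
    model = YOLO("models/DSF_Mleda_250.pt")

    # "butterfly.jpg" is a hypothetical input; conf and imgsz match the defaults used in app.py.
    results = model("butterfly.jpg", conf=0.25, imgsz=1280)

    # Each result exposes boxes and keypoints, as consumed by draw_detections() in app.py.
    print(results[0].boxes.xywh, results[0].keypoints.data)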
requirements.txt ADDED
@@ -0,0 +1,14 @@
+ ultralytics==8.3.88
+ torch==2.6.0
+ torchvision==0.21.0
+ pycocotools==2.0.7
+ PyYAML==6.0.1
+ scipy==1.15.2
+ gradio==5.29.0
+ numpy==1.26.4
+ psutil==5.9.8
+ py-cpuinfo==9.0.0
+ safetensors==0.4.3
+ pillow==11.1.0
+ opencv-python==4.11.0.86
+ opencv-python-headless==4.7.0.72