Spaces:
Sleeping
Sleeping
Theivaprakasham Hari
commited on
Commit
·
c9da29e
1
Parent(s):
b49bef5
First push
Browse files- app.py +125 -0
- models/DSF_Mleda_250.pt +3 -0
- models/DSF_WSF_Mleda_450.pt +3 -0
- models/WSF_Mleda_200.pt +3 -0
- requirements.txt +15 -0
app.py
ADDED
@@ -0,0 +1,125 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import gradio as gr
|
2 |
+
import numpy as np
|
3 |
+
from PIL import Image
|
4 |
+
from ultralytics import YOLO
|
5 |
+
import cv2
|
6 |
+
|
7 |
+
# Model configurations
|
8 |
+
MODEL_CONFIGS = {
|
9 |
+
"Dry Season Form": {
|
10 |
+
"path": r"models\DSF_Mleda_250.pt",
|
11 |
+
"labels": [
|
12 |
+
"F1", "F2", "F3", "F4", "F5", "F6", "F7", "F8", "F9", "F10", "F11", "F12", "F13", "F14", "F15", "F16",
|
13 |
+
"H1", "H2", "H3", "H4", "H5", "H6", "H7", "H8", "H9", "H10", "H11", "H12", "H13", "H14", "H15", "H16", "H17",
|
14 |
+
"fs1", "fs2", "fs3", "fs4", "fs5",
|
15 |
+
"hs1", "hs2", "hs3", "hs4", "hs5", "hs6",
|
16 |
+
"sc1", "sc2",
|
17 |
+
"sex", "right", "left", "grey", "black", "white"
|
18 |
+
],
|
19 |
+
"imgsz": 1280
|
20 |
+
},
|
21 |
+
"Wet Season Form": {
|
22 |
+
"path": r"models\WSF_Mleda_200.pt",
|
23 |
+
"labels": [
|
24 |
+
'F12', 'F14', 'fs1', 'fs2', 'fs3', 'fs4', 'fs5',
|
25 |
+
'H12', 'H14', 'hs1', 'hs2', 'hs3', 'hs4', 'hs5', 'hs6',
|
26 |
+
'white', 'black', 'grey', 'sex', 'blue', 'green', 'red', 'sc1', 'sc2'
|
27 |
+
],
|
28 |
+
"imgsz": 1280
|
29 |
+
},
|
30 |
+
"All Season Form": {
|
31 |
+
"path": r"models\DSF_WSF_Mleda_450.pt",
|
32 |
+
"labels": [
|
33 |
+
"F1", "F2", "F3", "F4", "F5", "F6", "F7", "F8", "F9", "F10", "F11", "F12", "F13", "F14", "F15", "F16",
|
34 |
+
"H1", "H2", "H3", "H4", "H5", "H6", "H7", "H8", "H9", "H10", "H11", "H12", "H13", "H14", "H15", "H16", "H17",
|
35 |
+
"fs1", "fs2", "fs3", "fs4", "fs5",
|
36 |
+
"hs1", "hs2", "hs3", "hs4", "hs5", "hs6",
|
37 |
+
"sc1", "sc2",
|
38 |
+
"sex", "right", "left", "grey", "black", "white"
|
39 |
+
],
|
40 |
+
"imgsz": 1280
|
41 |
+
}
|
42 |
+
}
|
43 |
+
|
44 |
+
def hex_to_bgr(hex_color: str) -> tuple:
|
45 |
+
"""Convert #RRGGBB hex color to BGR tuple."""
|
46 |
+
hex_color = hex_color.lstrip("#")
|
47 |
+
if len(hex_color) != 6:
|
48 |
+
return (0, 255, 0) # Default to green if invalid
|
49 |
+
r = int(hex_color[0:2], 16)
|
50 |
+
g = int(hex_color[2:4], 16)
|
51 |
+
b = int(hex_color[4:6], 16)
|
52 |
+
return (b, g, r)
|
53 |
+
|
54 |
+
def load_model(path):
|
55 |
+
"""Load YOLO model from the given path."""
|
56 |
+
return YOLO(path)
|
57 |
+
|
58 |
+
def draw_detections(image: np.ndarray, results, labels, keypoint_threshold: float, show_labels: bool, point_size: int, point_color: str, label_size: float) -> np.ndarray:
|
59 |
+
img = image.copy()
|
60 |
+
color_bgr = hex_to_bgr(point_color)
|
61 |
+
|
62 |
+
for result in results:
|
63 |
+
boxes = result.boxes.xywh.cpu().numpy()
|
64 |
+
cls_ids = result.boxes.cls.int().cpu().numpy()
|
65 |
+
confs = result.boxes.conf.cpu().numpy()
|
66 |
+
kpts_all = result.keypoints.data.cpu().numpy()
|
67 |
+
|
68 |
+
for (x_c, y_c, w, h), cls_id, conf, kpts in zip(boxes, cls_ids, confs, kpts_all):
|
69 |
+
x1 = int(x_c - w/2); y1 = int(y_c - h/2)
|
70 |
+
x2 = int(x_c + w/2); y2 = int(y_c + h/2)
|
71 |
+
|
72 |
+
cv2.rectangle(img, (x1, y1), (x2, y2), color=(255,255,0), thickness=2)
|
73 |
+
|
74 |
+
class_name = result.names[int(cls_id)]
|
75 |
+
text = f"{class_name} {conf:.2f}"
|
76 |
+
(tw, th), _ = cv2.getTextSize(text, cv2.FONT_HERSHEY_SIMPLEX, 0.6, 1)
|
77 |
+
cv2.rectangle(img, (x1, y1 - th - 4), (x1 + tw, y1), (255,255,0), cv2.FILLED)
|
78 |
+
cv2.putText(img, text, (x1, y1 - 4), cv2.FONT_HERSHEY_SIMPLEX, 0.6, (0,0,0), 1, cv2.LINE_AA)
|
79 |
+
|
80 |
+
for i, (x, y, v) in enumerate(kpts):
|
81 |
+
if v > keypoint_threshold and i < len(labels):
|
82 |
+
xi, yi = int(x), int(y)
|
83 |
+
cv2.circle(img, (xi, yi), int(point_size), color_bgr, -1, cv2.LINE_AA)
|
84 |
+
if show_labels:
|
85 |
+
cv2.putText(img, labels[i], (xi + 3, yi + 3), cv2.FONT_HERSHEY_SIMPLEX, label_size, (255,0,0), 2, cv2.LINE_AA)
|
86 |
+
return img
|
87 |
+
|
88 |
+
|
89 |
+
def predict_and_annotate(input_image: Image.Image, conf_threshold: float, keypoint_threshold: float, model_choice: str, show_labels: bool, point_size: int, point_color: str, label_size: float):
|
90 |
+
config = MODEL_CONFIGS[model_choice]
|
91 |
+
model = load_model(config["path"])
|
92 |
+
labels = config["labels"]
|
93 |
+
imgsz = config["imgsz"]
|
94 |
+
|
95 |
+
img_rgb = np.array(input_image.convert("RGB"))
|
96 |
+
img_bgr = cv2.cvtColor(img_rgb, cv2.COLOR_RGB2BGR)
|
97 |
+
|
98 |
+
results = model(img_bgr, conf=conf_threshold, imgsz=imgsz)
|
99 |
+
annotated_bgr = draw_detections(img_bgr, results, labels, keypoint_threshold, show_labels, point_size, point_color, label_size)
|
100 |
+
annotated_rgb = cv2.cvtColor(annotated_bgr, cv2.COLOR_BGR2RGB)
|
101 |
+
return Image.fromarray(annotated_rgb)
|
102 |
+
|
103 |
+
|
104 |
+
# Gradio Interface
|
105 |
+
app = gr.Interface(
|
106 |
+
fn=predict_and_annotate,
|
107 |
+
inputs=[
|
108 |
+
gr.Image(type="pil", label="Upload Image"),
|
109 |
+
gr.Slider(minimum=0.0, maximum=1.0, value=0.25, step=0.01, label="Confidence Threshold"),
|
110 |
+
gr.Slider(minimum=0.0, maximum=1.0, value=0.5, step=0.01, label="Keypoint Visibility Threshold"),
|
111 |
+
gr.Radio(choices=list(MODEL_CONFIGS.keys()), label="Select Model", value="Dry Season Form"),
|
112 |
+
gr.Checkbox(label="Show Keypoint Labels", value=True),
|
113 |
+
gr.Slider(minimum=1, maximum=20, value=8, step=0.1, label="Keypoint Size"),
|
114 |
+
gr.ColorPicker(label="Keypoint Color", value="#00FF00"),
|
115 |
+
gr.Slider(minimum=0.3, maximum=3.0, value=1.0, step=0.1, label="Keypoint Label Font Size")
|
116 |
+
],
|
117 |
+
outputs=gr.Image(type="pil", label="Detection Result", format="png"),
|
118 |
+
title="🦋 Melanitis leda Landmark Identification",
|
119 |
+
description="Upload an image and select the model. Customize detection and keypoint display settings.",
|
120 |
+
flagging_mode="never"
|
121 |
+
)
|
122 |
+
|
123 |
+
|
124 |
+
if __name__ == "__main__":
|
125 |
+
app.launch()
|
models/DSF_Mleda_250.pt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:5b51ea730cb1c03399bc7d1de9fe51ca6f531af0d29c3064de9638a4a8ec701c
|
3 |
+
size 56864045
|
models/DSF_WSF_Mleda_450.pt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:59cd20c486514ff18080413657d784f76c95937db5d5f6fa8cfae1f95b24951a
|
3 |
+
size 56674033
|
models/WSF_Mleda_200.pt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:33b8298a9121ae3ccf5b5d5aff000d6dddda6f6e1c3e35a09f1215b8f53dd47c
|
3 |
+
size 53621489
|
requirements.txt
ADDED
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
ultralytics==8.3.88
|
2 |
+
torch==2.6.0
|
3 |
+
torchvision==0.21.0
|
4 |
+
pycocotools==2.0.7
|
5 |
+
PyYAML==6.0.1
|
6 |
+
scipy==1.15.2
|
7 |
+
gradio==5.29.0
|
8 |
+
numpy==1.26.4
|
9 |
+
opencv-python==4.9.0.80
|
10 |
+
psutil==5.9.8
|
11 |
+
py-cpuinfo==9.0.0
|
12 |
+
safetensors==0.4.3
|
13 |
+
pillow==11.1.0
|
14 |
+
opencv-python==4.11.0.86
|
15 |
+
opencv-python-headless==4.7.0.72
|