developer0hye committed · Commit 4a4a3ed · verified · 1 Parent(s): e85fecb

Update app.py

Files changed (1): app.py (+260, -0)
app.py CHANGED
@@ -0,0 +1,260 @@
+ """
+ Copyright (c) 2024 The D-FINE Authors. All Rights Reserved.
+ """
+
+ import os
+ import sys
+ import torch
+ import torch.nn as nn
+ import torchvision.transforms as T
+ import supervision as sv
+ from PIL import Image
+ import requests
+ import yaml
+ import gradio as gr
+ import numpy as np
+
+ from src.core import YAMLConfig
+
+
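+ # Each entry pairs a model variant with its YAML config, its class-name file, and the URL of its pretrained weights.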
+ model_configs = {
+     "dfine_n_coco":
+         {"cfgfile": "configs/dfine/dfine_hgnetv2_n_coco.yml",
+          "classinfofile": "configs/coco.yml",
+          "weights": "https://github.com/Peterande/storage/releases/download/dfinev1.0/dfine_n_coco.pth"},
+     "dfine_s_coco":
+         {"cfgfile": "configs/dfine/dfine_hgnetv2_s_coco.yml",
+          "classinfofile": "configs/coco.yml",
+          "weights": "https://github.com/Peterande/storage/releases/download/dfinev1.0/dfine_s_coco.pth"},
+     "dfine_m_coco":
+         {"cfgfile": "configs/dfine/dfine_hgnetv2_m_coco.yml",
+          "classinfofile": "configs/coco.yml",
+          "weights": "https://github.com/Peterande/storage/releases/download/dfinev1.0/dfine_m_coco.pth"},
+     "dfine_l_coco":
+         {"cfgfile": "configs/dfine/dfine_hgnetv2_l_coco.yml",
+          "classinfofile": "configs/coco.yml",
+          "weights": "https://github.com/Peterande/storage/releases/download/dfinev1.0/dfine_l_coco.pth"},
+     "dfine_x_coco":
+         {"cfgfile": "configs/dfine/dfine_hgnetv2_x_coco.yml",
+          "classinfofile": "configs/coco.yml",
+          "weights": "https://github.com/Peterande/storage/releases/download/dfinev1.0/dfine_x_coco.pth"},
+     "dfine_s_obj365":
+         {"cfgfile": "configs/dfine/objects365/dfine_hgnetv2_s_obj365.yml",
+          "classinfofile": "configs/obj365.yml",
+          "weights": "https://github.com/Peterande/storage/releases/download/dfinev1.0/dfine_s_obj365.pth"},
+     "dfine_m_obj365":
+         {"cfgfile": "configs/dfine/objects365/dfine_hgnetv2_m_obj365.yml",
+          "classinfofile": "configs/obj365.yml",
+          "weights": "https://github.com/Peterande/storage/releases/download/dfinev1.0/dfine_m_obj365.pth"},
+     "dfine_l_obj365":
+         {"cfgfile": "configs/dfine/objects365/dfine_hgnetv2_l_obj365.yml",
+          "classinfofile": "configs/obj365.yml",
+          "weights": "https://github.com/Peterande/storage/releases/download/dfinev1.0/dfine_l_obj365.pth"},
+     "dfine_l_obj365_e25":
+         {"cfgfile": "configs/dfine/objects365/dfine_hgnetv2_l_obj365.yml",
+          "classinfofile": "configs/obj365.yml",
+          "weights": "https://github.com/Peterande/storage/releases/download/dfinev1.0/dfine_l_obj365_e25.pth"},
+     "dfine_x_obj365":
+         {"cfgfile": "configs/dfine/objects365/dfine_hgnetv2_x_obj365.yml",
+          "classinfofile": "configs/obj365.yml",
+          "weights": "https://github.com/Peterande/storage/releases/download/dfinev1.0/dfine_x_obj365.pth"},
+     "dfine_s_obj2coco":
+         {"cfgfile": "configs/dfine/objects365/dfine_hgnetv2_s_obj2coco.yml",
+          "classinfofile": "configs/coco.yml",
+          "weights": "https://github.com/Peterande/storage/releases/download/dfinev1.0/dfine_s_obj2coco.pth"},
+     "dfine_m_obj2coco":
+         {"cfgfile": "configs/dfine/objects365/dfine_hgnetv2_m_obj2coco.yml",
+          "classinfofile": "configs/coco.yml",
+          "weights": "https://github.com/Peterande/storage/releases/download/dfinev1.0/dfine_m_obj2coco.pth"},
+     "dfine_l_obj2coco_e25":
+         {"cfgfile": "configs/dfine/objects365/dfine_hgnetv2_l_obj2coco.yml",
+          "classinfofile": "configs/coco.yml",
+          "weights": "https://github.com/Peterande/storage/releases/download/dfinev1.0/dfine_l_obj2coco_e25.pth"},
+     "dfine_x_obj2coco":
+         {"cfgfile": "configs/dfine/objects365/dfine_hgnetv2_x_obj2coco.yml",
+          "classinfofile": "configs/coco.yml",
+          "weights": "https://github.com/Peterande/storage/releases/download/dfinev1.0/dfine_x_obj2coco.pth"},
+ }
+
+
+ def download_weights(model_name):
+     """Download model weights if not already present"""
+     weights_url = model_configs[model_name]["weights"]
+     # Directory path to save weight files
+     weights_dir = os.path.join(os.path.dirname(__file__), "weights")
+     # Weight file path
+     weights_path = os.path.join(weights_dir, model_name + ".pth")
+
+     # Create weights directory if it doesn't exist
+     if not os.path.exists(weights_dir):
+         os.makedirs(weights_dir)
+         print(f"Created directory: {weights_dir}")
+
+     # Check if file already exists
+     if os.path.exists(weights_path):
+         print(f"Weights file already exists at: {weights_path}")
+         return weights_path
+
+     # Download file
+     print(f"Downloading weights from {weights_url} to {weights_path}...")
+
+     response = requests.get(weights_url, stream=True)
+     response.raise_for_status()  # Check for download errors
+
+     with open(weights_path, 'wb') as f:
+         for chunk in response.iter_content(chunk_size=8192):
+             f.write(chunk)
+
+     print(f"Downloaded weights to: {weights_path}")
+     return weights_path
+
+
+ def process_image_for_gradio(model, device, image, model_name, threshold=0.4):
+     """Process image function for Gradio interface"""
+     if isinstance(image, np.ndarray):
+         # Convert NumPy array to PIL image
+         im_pil = Image.fromarray(image)
+     else:
+         im_pil = image
+
+     # Load class information
+     classinfofile = model_configs[model_name]["classinfofile"]
+     classinfo = yaml.load(open(classinfofile, "r"), Loader=yaml.FullLoader)["names"]
+     # The COCO class file is 0-indexed, the Objects365 file is 1-indexed
+     indexing_method = "0-based" if "coco" in classinfofile else "1-based"
+
+     w, h = im_pil.size
+     orig_size = torch.tensor([[w, h]]).to(device)
+
+     transforms = T.Compose(
+         [
+             T.Resize((640, 640)),
+             T.ToTensor(),
+         ]
+     )
+     im_data = transforms(im_pil).unsqueeze(0).to(device)
+
+     output = model(im_data, orig_size)
+     labels, boxes, scores = output
+
+     # Visualize results
+     detections = sv.Detections(
+         xyxy=boxes[0].detach().cpu().numpy(),
+         confidence=scores[0].detach().cpu().numpy(),
+         class_id=labels[0].detach().cpu().numpy().astype(int),
+     )
+     detections = detections[detections.confidence > threshold]
+
+     text_scale = sv.calculate_optimal_text_scale(resolution_wh=im_pil.size)
+     line_thickness = sv.calculate_optimal_line_thickness(resolution_wh=im_pil.size)
+
+     box_annotator = sv.BoxAnnotator(thickness=line_thickness)
+     label_annotator = sv.LabelAnnotator(text_scale=text_scale, smart_position=True)
+
+     label_texts = [
+         f"{classinfo[class_id if indexing_method == '0-based' else class_id - 1]} {confidence:.2f}"
+         for class_id, confidence
+         in zip(detections.class_id, detections.confidence)
+     ]
+
+     result_image = im_pil.copy()
+     result_image = box_annotator.annotate(scene=result_image, detections=detections)
+     result_image = label_annotator.annotate(
+         scene=result_image,
+         detections=detections,
+         labels=label_texts
+     )
+
+     detection_info = [
+         f"{classinfo[class_id if indexing_method == '0-based' else class_id - 1]}: {confidence:.2f}, bbox: [{xyxy[0]:.1f}, {xyxy[1]:.1f}, {xyxy[2]:.1f}, {xyxy[3]:.1f}]"
+         for class_id, confidence, xyxy
+         in zip(detections.class_id, detections.confidence, detections.xyxy)
+     ]
+
+     return result_image, "\n".join(detection_info)
+
+
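+ # Chains the deploy-mode model and postprocessor so one forward pass returns (labels, boxes, scores) scaled back to the original image size.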
+ class ModelWrapper(nn.Module):
+     def __init__(self, cfg):
+         super().__init__()
+         self.model = cfg.model.deploy()
+         self.postprocessor = cfg.postprocessor.deploy()
+
+     def forward(self, images, orig_target_sizes):
+         outputs = self.model(images)
+         outputs = self.postprocessor(outputs, orig_target_sizes)
+         return outputs
+
+
+ def load_model(model_name):
+     cfgfile = model_configs[model_name]["cfgfile"]
+     weights_path = download_weights(model_name)
+
+     cfg = YAMLConfig(cfgfile, resume=weights_path)
+
+     # Don't fetch pretrained backbone weights; the full checkpoint is loaded below
+     if "HGNetv2" in cfg.yaml_cfg:
+         cfg.yaml_cfg["HGNetv2"]["pretrained"] = False
+
+     checkpoint = torch.load(weights_path, map_location="cpu")
+     # Prefer the EMA weights when the checkpoint provides them
+     state = checkpoint["ema"]["module"] if "ema" in checkpoint else checkpoint["model"]
+
+     cfg.model.load_state_dict(state)
+
+     device = "cuda" if torch.cuda.is_available() else "cpu"
+     model = ModelWrapper(cfg).to(device)
+     model.eval()
+
+     return model, device
+
+
+ # Dictionary to store loaded models
+ loaded_models = {}
+
+ def process_image(image, model_name, confidence_threshold):
+     """Main processing function for Gradio interface"""
+     global loaded_models
+
+     # Load model if not already loaded
+     if model_name not in loaded_models:
+         print(f"Loading model: {model_name}")
+         model, device = load_model(model_name)
+         loaded_models[model_name] = (model, device)
+     else:
+         print(f"Using cached model: {model_name}")
+         model, device = loaded_models[model_name]
+
+     # Process the image
+     return process_image_for_gradio(model, device, image, model_name, confidence_threshold)
+
+
+ # Create Gradio interface
+ demo = gr.Interface(
+     fn=process_image,
+     inputs=[
+         gr.Image(type="pil", label="Input Image"),
+         gr.Dropdown(
+             choices=list(model_configs.keys()),
+             value="dfine_n_coco",
+             label="Model Selection"
+         ),
+         gr.Slider(
+             minimum=0.1,
+             maximum=0.9,
+             value=0.4,
+             step=0.05,
+             label="Confidence Threshold"
+         )
+     ],
+     outputs=[
+         gr.Image(type="pil", label="Detection Result"),
+         gr.Textbox(label="Detected Objects")
+     ],
+     title="D-FINE Object Detection Demo",
+     description="Upload an image to see object detection results using the D-FINE model. You can select different models and adjust the confidence threshold.",
+     examples=[
+         ["examples/image1.jpg", "dfine_n_coco", 0.4],
+     ]
+ )
+
+ if __name__ == "__main__":
+     # Launch the Gradio app
+     demo.launch(share=True)
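
For reference, the pipeline above can also be driven without the Gradio UI. Below is a minimal local-usage sketch, assuming this file is saved as app.py at the root of the D-FINE repository (so that src/ and configs/ resolve) and that an example image exists at examples/image1.jpg:

    from PIL import Image
    from app import load_model, process_image_for_gradio

    # Parses the YAML config, downloads weights on first use, and picks CUDA when available
    model, device = load_model("dfine_n_coco")

    image = Image.open("examples/image1.jpg").convert("RGB")
    annotated, info = process_image_for_gradio(model, device, image, "dfine_n_coco", threshold=0.4)

    annotated.save("result.jpg")  # PIL image with boxes and labels drawn
    print(info)                   # one "class: score, bbox: [...]" line per detection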