""" |
|
Copyright (c) 2024 The D-FINE Authors. All Rights Reserved. |
|
""" |
|
import gradio as gr |
|
import spaces |
|
import os |
|
import sys |
|
import torch |
|
import torch.nn as nn |
|
import torchvision.transforms as T |
|
import supervision as sv |
|
from PIL import Image |
|
import requests |
|
import yaml |
|
import numpy as np |
|
import gc |
|
|
|
from src.core import YAMLConfig |
|
|
|
|
|
model_configs = {
    "dfine_n_custom": {
        "cfgfile": "configs/dfine/custom/dfine_hgnetv2_n_custom.yml",
        "classinfofile": "configs/custom_info_coco.yml",
        "weights": "https://github.com/EnPaiva93/storage/raw/refs/heads/main/custom_model.pth",
    },
    "dfine_n_coco": {
        "cfgfile": "configs/dfine/dfine_hgnetv2_n_coco.yml",
        "classinfofile": "configs/coco.yml",
        "weights": "https://github.com/Peterande/storage/releases/download/dfinev1.0/dfine_n_coco.pth",
    },
    "dfine_s_coco": {
        "cfgfile": "configs/dfine/dfine_hgnetv2_s_coco.yml",
        "classinfofile": "configs/coco.yml",
        "weights": "https://github.com/Peterande/storage/releases/download/dfinev1.0/dfine_s_coco.pth",
    },
    "dfine_m_coco": {
        "cfgfile": "configs/dfine/dfine_hgnetv2_m_coco.yml",
        "classinfofile": "configs/coco.yml",
        "weights": "https://github.com/Peterande/storage/releases/download/dfinev1.0/dfine_m_coco.pth",
    },
    "dfine_l_coco": {
        "cfgfile": "configs/dfine/dfine_hgnetv2_l_coco.yml",
        "classinfofile": "configs/coco.yml",
        "weights": "https://github.com/Peterande/storage/releases/download/dfinev1.0/dfine_l_coco.pth",
    },
    "dfine_x_coco": {
        "cfgfile": "configs/dfine/dfine_hgnetv2_x_coco.yml",
        "classinfofile": "configs/coco.yml",
        "weights": "https://github.com/Peterande/storage/releases/download/dfinev1.0/dfine_x_coco.pth",
    },
    "dfine_s_obj365": {
        "cfgfile": "configs/dfine/objects365/dfine_hgnetv2_s_obj365.yml",
        "classinfofile": "configs/obj365.yml",
        "weights": "https://github.com/Peterande/storage/releases/download/dfinev1.0/dfine_s_obj365.pth",
    },
    "dfine_m_obj365": {
        "cfgfile": "configs/dfine/objects365/dfine_hgnetv2_m_obj365.yml",
        "classinfofile": "configs/obj365.yml",
        "weights": "https://github.com/Peterande/storage/releases/download/dfinev1.0/dfine_m_obj365.pth",
    },
    "dfine_l_obj365": {
        "cfgfile": "configs/dfine/objects365/dfine_hgnetv2_l_obj365.yml",
        "classinfofile": "configs/obj365.yml",
        "weights": "https://github.com/Peterande/storage/releases/download/dfinev1.0/dfine_l_obj365.pth",
    },
    "dfine_l_obj365_e25": {
        "cfgfile": "configs/dfine/objects365/dfine_hgnetv2_l_obj365.yml",
        "classinfofile": "configs/obj365.yml",
        "weights": "https://github.com/Peterande/storage/releases/download/dfinev1.0/dfine_l_obj365_e25.pth",
    },
    "dfine_x_obj365": {
        "cfgfile": "configs/dfine/objects365/dfine_hgnetv2_x_obj365.yml",
        "classinfofile": "configs/obj365.yml",
        "weights": "https://github.com/Peterande/storage/releases/download/dfinev1.0/dfine_x_obj365.pth",
    },
    "dfine_s_obj2coco": {
        "cfgfile": "configs/dfine/objects365/dfine_hgnetv2_s_obj2coco.yml",
        "classinfofile": "configs/coco.yml",
        "weights": "https://github.com/Peterande/storage/releases/download/dfinev1.0/dfine_s_obj2coco.pth",
    },
    "dfine_m_obj2coco": {
        "cfgfile": "configs/dfine/objects365/dfine_hgnetv2_m_obj2coco.yml",
        "classinfofile": "configs/coco.yml",
        "weights": "https://github.com/Peterande/storage/releases/download/dfinev1.0/dfine_m_obj2coco.pth",
    },
    "dfine_l_obj2coco_e25": {
        "cfgfile": "configs/dfine/objects365/dfine_hgnetv2_l_obj2coco.yml",
        "classinfofile": "configs/coco.yml",
        "weights": "https://github.com/Peterande/storage/releases/download/dfinev1.0/dfine_l_obj2coco_e25.pth",
    },
    "dfine_x_obj2coco": {
        "cfgfile": "configs/dfine/objects365/dfine_hgnetv2_x_obj2coco.yml",
        "classinfofile": "configs/coco.yml",
        "weights": "https://github.com/Peterande/storage/releases/download/dfinev1.0/dfine_x_obj2coco.pth",
    },
}
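
# Naming convention (inferred from the D-FINE release assets above): the
# letter after "dfine_" is the model size (n/s/m/l/x); the suffix names the
# training data: `coco` = MS COCO, `obj365` = Objects365, `obj2coco` =
# Objects365 pretraining followed by COCO fine-tuning. `_e25` appears to
# denote a 25-epoch checkpoint.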
def download_weights(model_name):
    """Download model weights if not already present."""
    weights_url = model_configs[model_name]["weights"]
    weights_dir = os.path.join(os.path.dirname(__file__), "weights")
    weights_path = os.path.join(weights_dir, model_name + ".pth")

    if not os.path.exists(weights_dir):
        os.makedirs(weights_dir)
        print(f"Created directory: {weights_dir}")

    if os.path.exists(weights_path):
        print(f"Weights file already exists at: {weights_path}")
        return weights_path

    print(f"Downloading weights from {weights_url} to {weights_path}...")
    response = requests.get(weights_url, stream=True)
    response.raise_for_status()

    # Stream the checkpoint to disk in chunks so it never has to fit in memory.
    with open(weights_path, "wb") as f:
        for chunk in response.iter_content(chunk_size=8192):
            f.write(chunk)

    print(f"Downloaded weights to: {weights_path}")
    return weights_path

@torch.no_grad()
def process_image_for_gradio(model, device, image, model_name, threshold=0.4):
    """Run inference on a single image and return the annotated result."""
    if isinstance(image, np.ndarray):
        im_pil = Image.fromarray(image)
    else:
        im_pil = image

    classinfofile = model_configs[model_name]["classinfofile"]
    with open(classinfofile, "r") as f:
        classinfo = yaml.load(f, Loader=yaml.FullLoader)["names"]
    # COCO class files list names starting at id 0; Objects365 files start at 1.
    indexing_method = "0-based" if "coco" in classinfofile else "1-based"

    w, h = im_pil.size
    orig_size = torch.tensor([[w, h]]).to(device)

    # The deployed model takes a fixed 640x640 input in [0, 1]; the
    # postprocessor maps boxes back to the original resolution via orig_size.
    transforms = T.Compose(
        [
            T.Resize((640, 640)),
            T.ToTensor(),
        ]
    )
    im_data = transforms(im_pil).unsqueeze(0).to(device)

    labels, boxes, scores = model(im_data, orig_size)

    detections = sv.Detections(
        xyxy=boxes[0].detach().cpu().numpy(),
        confidence=scores[0].detach().cpu().numpy(),
        class_id=labels[0].detach().cpu().numpy().astype(int),
    )
    detections = detections[detections.confidence > threshold]

    text_scale = sv.calculate_optimal_text_scale(resolution_wh=im_pil.size)
    line_thickness = sv.calculate_optimal_line_thickness(resolution_wh=im_pil.size)

    box_annotator = sv.BoxAnnotator(thickness=line_thickness)
    label_annotator = sv.LabelAnnotator(text_scale=text_scale, smart_position=True)

    def class_name(class_id):
        return classinfo[class_id if indexing_method == "0-based" else class_id - 1]

    label_texts = [
        f"{class_name(class_id)} {confidence:.2f}"
        for class_id, confidence in zip(detections.class_id, detections.confidence)
    ]

    result_image = im_pil.copy()
    result_image = box_annotator.annotate(scene=result_image, detections=detections)
    result_image = label_annotator.annotate(
        scene=result_image,
        detections=detections,
        labels=label_texts,
    )

    detection_info = [
        f"{class_name(class_id)}: {confidence:.2f}, "
        f"bbox: [{xyxy[0]:.1f}, {xyxy[1]:.1f}, {xyxy[2]:.1f}, {xyxy[3]:.1f}]"
        for class_id, confidence, xyxy in zip(
            detections.class_id, detections.confidence, detections.xyxy
        )
    ]

    return result_image, "\n".join(detection_info)

class ModelWrapper(nn.Module):
    """Bundle the deployed model and its postprocessor into one module."""

    def __init__(self, cfg):
        super().__init__()
        self.model = cfg.model.deploy()
        self.postprocessor = cfg.postprocessor.deploy()

    def forward(self, images, orig_target_sizes):
        outputs = self.model(images)
        outputs = self.postprocessor(outputs, orig_target_sizes)
        return outputs

def reset_yaml_config():
    """Reset YAMLConfig's internal cached state before loading a new model."""
    if hasattr(YAMLConfig, "_instances"):
        YAMLConfig._instances = {}
    if hasattr(YAMLConfig, "_configs"):
        YAMLConfig._configs = {}

    # Reload the src.* modules to drop any module-level state left behind by a
    # previously loaded model.
    import importlib

    for module_name in list(sys.modules.keys()):
        if module_name.startswith("src."):
            try:
                importlib.reload(sys.modules[module_name])
            except Exception:
                pass

def load_model(model_name):
    """Build a D-FINE model from its config and load the downloaded weights."""
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    gc.collect()

    reset_yaml_config()

    cfgfile = model_configs[model_name]["cfgfile"]
    weights_path = download_weights(model_name)

    cfg = YAMLConfig(cfgfile, resume=weights_path)

    # The checkpoint already contains backbone weights, so skip downloading
    # pretrained HGNetv2 weights.
    if "HGNetv2" in cfg.yaml_cfg:
        cfg.yaml_cfg["HGNetv2"]["pretrained"] = False

    checkpoint = torch.load(weights_path, map_location="cpu")
    # Prefer the EMA weights when the checkpoint provides them.
    state = checkpoint["ema"]["module"] if "ema" in checkpoint else checkpoint["model"]

    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    gc.collect()

    cfg.model.load_state_dict(state, strict=False)

    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = ModelWrapper(cfg).to(device)
    model.eval()

    return model, device
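
# A minimal local-usage sketch (assumes the configs/ tree and an example
# image are present on disk):
#   model, device = load_model("dfine_n_coco")
#   img = Image.open("examples/image1.jpg")
#   annotated, info = process_image_for_gradio(model, device, img, "dfine_n_coco")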
@spaces.GPU
def process_image(image, model_name, confidence_threshold):
    """Main processing function for the Gradio interface."""
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    gc.collect()

    try:
        print(f"Loading model: {model_name}")
        model, device = load_model(model_name)
        result = process_image_for_gradio(
            model, device, image, model_name, confidence_threshold
        )
        del model
        # Returning from inside the try block still runs the cleanup below and
        # avoids referencing `result` when loading or inference has failed.
        return result
    finally:
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        gc.collect()
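
# Design note: the model is rebuilt on every request rather than cached. That
# trades latency for memory safety, which fits the Spaces ZeroGPU model where
# the GPU is attached only for the duration of the decorated call; on a
# persistent GPU you would likely cache the loaded model instead.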
demo = gr.Interface(
    fn=process_image,
    inputs=[
        gr.Image(type="pil", label="Input Image"),
        gr.Dropdown(
            choices=list(model_configs.keys()),
            value="dfine_n_coco",
            label="Model Selection",
        ),
        gr.Slider(
            minimum=0.1,
            maximum=0.9,
            value=0.4,
            step=0.05,
            label="Confidence Threshold",
        ),
    ],
    outputs=[
        gr.Image(type="pil", label="Detection Result"),
        gr.Textbox(label="Detected Objects"),
    ],
    title="D-FINE Object Detection Demo",
    description=(
        "Upload an image to see object detection results using the D-FINE "
        "model. You can select different models and adjust the confidence "
        "threshold."
    ),
    examples=[
        ["examples/image1.jpg", "dfine_n_coco", 0.4],
    ],
)

if __name__ == "__main__":
    demo.launch(share=True)