import os, sys
from copy import deepcopy

import torch
import torch.nn as nn
import torchvision.transforms as T
import gradio as gr
from PIL import Image

sys.path.append('./DETRPose')
sys.path.append('./DETRPose/tools/inference')

from DETRPose.src.core import LazyConfig, instantiate
from DETRPose.tools.inference.annotator import Annotator
from DETRPose.tools.inference.annotator_crowdpose import AnnotatorCrowdpose

DETRPOSE_MODELS = {
    # For COCO2017
    "DETRPose-N": ['./DETRPose/configs/detrpose/detrpose_hgnetv2_n.py', 'n'],
    "DETRPose-S": ['./DETRPose/configs/detrpose/detrpose_hgnetv2_s.py', 's'],
    "DETRPose-M": ['./DETRPose/configs/detrpose/detrpose_hgnetv2_m.py', 'm'],
    "DETRPose-L": ['./DETRPose/configs/detrpose/detrpose_hgnetv2_l.py', 'l'],
    "DETRPose-X": ['./DETRPose/configs/detrpose/detrpose_hgnetv2_x.py', 'x'],
    # For CrowdPose
    "DETRPose-N-CrowdPose": ['./DETRPose/configs/detrpose/detrpose_hgnetv2_n_crowdpose.py', 'n_crowdpose'],
    "DETRPose-S-CrowdPose": ['./DETRPose/configs/detrpose/detrpose_hgnetv2_s_crowdpose.py', 's_crowdpose'],
    "DETRPose-M-CrowdPose": ['./DETRPose/configs/detrpose/detrpose_hgnetv2_m_crowdpose.py', 'm_crowdpose'],
    "DETRPose-L-CrowdPose": ['./DETRPose/configs/detrpose/detrpose_hgnetv2_l_crowdpose.py', 'l_crowdpose'],
    "DETRPose-X-CrowdPose": ['./DETRPose/configs/detrpose/detrpose_hgnetv2_x_crowdpose.py', 'x_crowdpose'],
}

transforms = T.Compose(
    [
        T.Resize((640, 640)),
        T.ToTensor(),
    ]
)

example_images = [
    ["assets/example1.jpg"],
    ["assets/example2.jpg"],
]

description = """

DETRPose
Real-time end-to-end transformer model for multi-person pose estimation

Sebastian Janampa and Marios Pattichis

GitHub | Colab

## Getting Started
DETRPose is the first real-time end-to-end transformer model for multi-person pose estimation, achieving outstanding results on the COCO and CrowdPose datasets. In this work, we propose a new denoising technique suitable for pose estimation that uses the Object Keypoint Similarity (OKS) metric to generate positive and negative queries. Additionally, we develop a new classification head and a new classification loss that are variations of the LQE head and the varifocal loss used in D-FINE.

To get started, upload an image or select one of the examples below. You can choose between different model sizes, adjust the confidence threshold, and visualize the results.

### Acknowledgement
This work has been supported by [LambdaLab](https://lambda.ai).
"""


def create_model(model_name):
    """Build the requested DETRPose variant and load its released weights."""
    config_path, weight_suffix = DETRPOSE_MODELS[model_name]
    cfg = LazyConfig.load(config_path)

    # The released checkpoint already contains backbone weights, so skip the
    # separate pretrained-backbone download.
    if hasattr(cfg.model.backbone, 'pretrained'):
        cfg.model.backbone.pretrained = False

    download_url = f"https://github.com/SebastianJanampa/DETRPose/releases/download/model_weights/detrpose_hgnetv2_{weight_suffix}.pth"
    state_dict = torch.hub.load_state_dict_from_url(
        download_url, map_location="cpu", file_name=f"detrpose_hgnetv2_{weight_suffix}.pth"
    )

    model = instantiate(cfg.model)
    postprocessor = instantiate(cfg.postprocessor)

    model.load_state_dict(state_dict['model'], strict=True)

    class Model(nn.Module):
        """Bundle the detector and its postprocessor into a single module."""

        def __init__(self):
            super().__init__()
            self.model = model.deploy()
            self.postprocessor = postprocessor.deploy()

        def forward(self, images, orig_target_sizes):
            outputs = self.model(images)
            outputs = self.postprocessor(outputs, orig_target_sizes)
            return outputs

    deployed = Model()
    deployed.eval()

    # CrowdPose models use a different keypoint skeleton, so select the
    # matching annotator for drawing.
    global Drawer
    if 'crowdpose' in weight_suffix:
        Drawer = AnnotatorCrowdpose
    else:
        Drawer = Annotator

    return deployed


def draw(image, scores, labels, keypoints, h, w, thrh):
    """Draw every skeleton whose confidence exceeds `thrh` on a copy of the image."""
    annotator = Drawer(deepcopy(image))
    for kpt, score in zip(keypoints, scores):
        if score > thrh:
            annotator.kpts(kpt, [h, w])
    annotated_image = annotator.result()
    return annotated_image[..., ::-1]  # flip the channel order for display


def filter(lines, scores, threshold):
    """Keep only the detections whose score exceeds the threshold."""
    filtered_lines, filter_scores = [], []
    for line, scr in zip(lines, scores):
        idx = scr > threshold
        filtered_lines.append(line[idx])
        filter_scores.append(scr[idx])
    return filtered_lines, filter_scores


def process_results(image_path, model_size, threshold):
    """Run the selected model on the image and return the detected keypoints."""
    if image_path is None:
        raise gr.Error("Please upload an image first.")

    model = create_model(model_size)

    im_pil = Image.open(image_path).convert("RGB")
    w, h = im_pil.size
    orig_size = torch.tensor([[w, h]])

    im_data = transforms(im_pil).unsqueeze(0)
    with torch.no_grad():
        output = model(im_data, orig_size)
    scores, labels, keypoints = output
    scores, labels, keypoints = scores[0], labels[0], keypoints[0]

    annotated_image = draw(im_pil, scores, labels, keypoints, h, w, thrh=threshold)
    # Also return the raw detections so the threshold slider can re-draw
    # without re-running the model.
    return annotated_image, (scores, labels, keypoints, h, w)


def update_threshold(image_path, raw_results, threshold):
    """Re-draw the cached detections when the confidence slider moves."""
    if raw_results is None:
        return gr.update()  # nothing detected yet; leave the output unchanged
    scores, labels, keypoints, h, w = raw_results
    im_pil = Image.open(image_path).convert("RGB")
    annotated_image = draw(im_pil, scores, labels, keypoints, h, w, thrh=threshold)
    return annotated_image


def update_model(image_path, model_size, threshold):
    """Re-run detection with a newly selected model size."""
    if image_path is None:
        raise gr.Error("Please upload an image first.")
    return process_results(image_path, model_size, threshold)
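
# The app description above mentions that DETRPose builds its denoising
# queries from the Object Keypoint Similarity (OKS) metric. For reference,
# this is a minimal sketch of COCO-style OKS; the helper name and signature
# are illustrative, not part of the DETRPose codebase:
#
#     OKS = sum_i exp(-d_i^2 / (2 * s^2 * k_i^2)) * [v_i > 0] / sum_i [v_i > 0]
#
# where d_i is the distance between predicted and ground-truth keypoint i,
# s^2 is the object scale (area), k_i = 2 * sigma_i is a per-keypoint
# constant, and v_i is the visibility flag.
def _oks_sketch(pred_kpts, gt_kpts, visibility, area, sigmas):
    """pred_kpts, gt_kpts: (K, 2) tensors; visibility: (K,); area: scalar."""
    d2 = ((pred_kpts - gt_kpts) ** 2).sum(dim=-1)  # squared distances, (K,)
    k2 = (2 * sigmas) ** 2                         # per-keypoint constants k_i^2
    e = d2 / (2 * area * k2 + 1e-9)                # normalized error per keypoint
    vis = (visibility > 0).float()
    return (torch.exp(-e) * vis).sum() / vis.sum().clamp(min=1)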
def main():
    # Build the Gradio interface.
    with gr.Blocks() as demo:
        gr.Markdown(description)
        with gr.Row():
            with gr.Column():
                gr.Markdown("""## Input Image""")
                image_path = gr.Image(label="Upload image", type="filepath")
                model_size = gr.Dropdown(
                    choices=list(DETRPOSE_MODELS.keys()),
                    label="Choose a DETRPose model.",
                    value="DETRPose-M",
                )
                threshold = gr.Slider(
                    label="Confidence Threshold",
                    minimum=0.0,
                    maximum=1.0,
                    step=0.05,
                    interactive=True,
                    value=0.50,
                )
                submit_btn = gr.Button("Detect Human Keypoints")
                gr.Examples(examples=example_images, inputs=[image_path, model_size])

            with gr.Column():
                gr.Markdown("""## Results""")
                image_output = gr.Image(label="Detected Human Keypoints")

        # Cache the raw detections between events.
        raw_results = gr.State()
        plot_inputs = [
            raw_results,
            threshold,
        ]

        # Run detection when the button is clicked.
        submit_btn.click(
            fn=process_results,
            inputs=[image_path, model_size] + plot_inputs[1:],
            outputs=[image_output, raw_results],
        )

        # Re-draw the cached detections when the threshold slider changes.
        threshold.change(
            fn=update_threshold,
            inputs=[image_path] + plot_inputs,
            outputs=[image_output],
        )

    demo.launch()


if __name__ == "__main__":
    main()
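
# Programmatic usage (a sketch that bypasses the UI; assumes the checkpoint
# download inside create_model succeeds):
#
#     model = create_model("DETRPose-M")
#     im = Image.open("assets/example1.jpg").convert("RGB")
#     size = torch.tensor([[im.width, im.height]])
#     scores, labels, keypoints = model(transforms(im).unsqueeze(0), size)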