import os
os.environ["PYOPENGL_PLATFORM"] = "egl"
os.environ["MESA_GL_VERSION_OVERRIDE"] = "4.1"

import sys
import argparse
import json
from pathlib import Path
from typing import Dict, Optional

import cv2
import numpy as np
import torch
import gradio as gr
import spaces
from torchvision import transforms
from PIL import Image, ImageDraw
from huggingface_hub import hf_hub_download
from ultralytics import YOLO
from lang_sam import LangSAM
from wilor.models import load_wilor
from wilor.utils import recursive_to
from wilor.datasets.vitdet_dataset import ViTDetDataset
from hort.models import load_hort
from hort.utils.renderer import Renderer, cam_crop_to_new
from hort.utils.img_utils import process_bbox, generate_patch_image, PerspectiveCamera

# Colors used by the renderer for the hand mesh and object point cloud
LIGHT_PURPLE = (0.25098039, 0.274117647, 0.65882353)
STEEL_BLUE = (0.2745098, 0.5098039, 0.7058824)

# Download and load checkpoints
wilor_checkpoint_path = hf_hub_download(repo_id="zerchen/hort_models", filename="wilor_final.ckpt")
hort_checkpoint_path = hf_hub_download(repo_id="zerchen/hort_models", filename="hort_final.pth.tar")
wilor_model, wilor_model_cfg = load_wilor(checkpoint_path=wilor_checkpoint_path, cfg_path='./pretrained_models/model_config.yaml')
hand_detector = YOLO('./pretrained_models/detector.pt')

# Setup the renderer
renderer = Renderer(wilor_model_cfg, faces=wilor_model.mano.faces)

# Setup the SAM model
sam_model = LangSAM(sam_type="sam2.1_hiera_large")

# Setup the HORT model
hort_model = load_hort(hort_checkpoint_path)

device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
wilor_model = wilor_model.to(device)
hand_detector = hand_detector.to(device)
hort_model = hort_model.to(device)
wilor_model.eval()
hort_model.eval()

# ImageNet normalization applied to the 224x224 hand-object crop fed to HORT
image_transform = transforms.Compose([transforms.ToPILImage(),
                                      transforms.ToTensor(),
                                      transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])


@spaces.GPU()
def run_model(image, conf, IoU_threshold=0.5):
    img_cv2 = image[..., ::-1]
    img_pil = Image.fromarray(image)

    # Localize the manipulated object and the hand with language-prompted SAM
    pred_obj = sam_model.predict([img_pil], ["manipulated object"])
    pred_hand = sam_model.predict([img_pil], ["hand"])

    bbox_obj = pred_obj[0]["boxes"][0].reshape((-1, 2))
    mask_obj = pred_obj[0]["masks"][0]
    bbox_hand = pred_hand[0]["boxes"][0].reshape((-1, 2))
    mask_hand = pred_hand[0]["masks"][0]

    # The hand-object crop is the union of both boxes, padded by 10 pixels on each side
    tl = np.min(np.concatenate([bbox_obj, bbox_hand], axis=0), axis=0)
    br = np.max(np.concatenate([bbox_obj, bbox_hand], axis=0), axis=0)
    box_size = br - tl
    bbox = np.concatenate([tl - 10, box_size + 20], axis=0)
    ho_bbox = process_bbox(bbox)

    # Detect hands with YOLO; the class id encodes left (0) vs. right (1)
    detections = hand_detector(img_cv2, conf=conf, verbose=False, iou=IoU_threshold)[0]

    bboxes = []
    is_right = []
    for det in detections:
        Bbox = det.boxes.data.cpu().detach().squeeze().numpy()
        is_right.append(det.boxes.cls.cpu().detach().squeeze().item())
        bboxes.append(Bbox[:4].tolist())

    if len(bboxes) == 1:
        boxes = np.stack(bboxes)
        right = np.stack(is_right)
        if not right:
            # Left hand: mirror the image and the boxes so the models always see a right hand
            new_x1 = img_cv2.shape[1] - boxes[0][2]
            new_x2 = img_cv2.shape[1] - boxes[0][0]
            boxes[0][0] = new_x1
            boxes[0][2] = new_x2
            ho_bbox[0] = img_cv2.shape[1] - (ho_bbox[0] + ho_bbox[2])
            img_cv2 = cv2.flip(img_cv2, 1)
            right[0] = 1.
        # Crop the hand-object region and run WiLoR for the 3D hand reconstruction
        crop_img_cv2, _ = generate_patch_image(img_cv2, ho_bbox, (224, 224), 0, 1.0, 0)
        dataset = ViTDetDataset(wilor_model_cfg, img_cv2, boxes, right, rescale_factor=2.0)
        dataloader = torch.utils.data.DataLoader(dataset, batch_size=16, shuffle=False, num_workers=0)

        for batch in dataloader:
            batch = recursive_to(batch, device)
            with torch.no_grad():
                out = wilor_model(batch)

            pred_cam = out['pred_cam']
            box_center = batch["box_center"].float()
            box_size = batch["box_size"].float()
            img_size = batch["img_size"].float()
            scaled_focal_length = wilor_model_cfg.EXTRA.FOCAL_LENGTH / wilor_model_cfg.MODEL.IMAGE_SIZE * 224
            # Convert the crop-space camera prediction into a translation for the hand-object crop
            pred_cam_t_full = cam_crop_to_new(pred_cam, box_center, box_size, img_size,
                                              torch.from_numpy(np.array(ho_bbox, dtype=np.float32))[None, :].to(img_size.device),
                                              scaled_focal_length).detach().cpu().numpy()

            batch_size = batch['img'].shape[0]
            for n in range(batch_size):
                verts = out['pred_vertices'][n].detach().cpu().numpy()
                joints = out['pred_keypoints_3d'][n].detach().cpu().numpy()
                is_right = batch['right'][n].cpu().numpy()
                # Approximate palm center from two MANO mesh vertices
                palm = (verts[95] + verts[22]) / 2
                cam_t = pred_cam_t_full[n]

                # HORT takes the normalized crop plus the hand mesh placed in camera space
                img_input = image_transform(crop_img_cv2[:, :, ::-1]).unsqueeze(0).cuda()
                camera = PerspectiveCamera(5000 / 256 * 224, 5000 / 256 * 224, 112, 112)
                cam_intr = camera.intrinsics

                metas = dict()
                metas["right_hand_verts_3d"] = torch.from_numpy((verts + cam_t)[None]).cuda()
                metas["right_hand_joints_3d"] = torch.from_numpy((joints + cam_t)[None]).cuda()
                metas["right_hand_palm"] = torch.from_numpy((palm + cam_t)[None]).cuda()
                metas["cam_intr"] = torch.from_numpy(cam_intr[None]).cuda()

                with torch.amp.autocast(device_type='cuda', dtype=torch.float16):
                    pc_results = hort_model(img_input, metas)
                objtrans = pc_results["objtrans"][0].detach().cpu().numpy()
                # Rescale the predicted object point cloud
                pointclouds_up = pc_results["pointclouds_up"][0].detach().cpu().numpy() * 0.3

                reconstructions = {'verts': verts, 'palm': palm, 'objtrans': objtrans, 'objpcs': pointclouds_up,
                                   'cam_t': cam_t, 'right': is_right, 'img_size': 224, 'focal': scaled_focal_length}

        return crop_img_cv2[..., ::-1].astype(np.float32) / 255.0, len(detections), reconstructions
    else:
        # No single hand detected: return the unflipped hand-object crop without a reconstruction
        crop_img_cv2, _ = generate_patch_image(img_cv2, ho_bbox, (224, 224), 0, 1.0, 0)
        return crop_img_cv2[..., ::-1].astype(np.float32) / 255.0, len(detections), None


def render_reconstruction(image, conf, IoU_threshold=0.3):
    input_img, num_dets, reconstructions = run_model(image, conf, IoU_threshold=0.5)
    if num_dets == 1:
        # Render front view
        misc_args = dict(mesh_base_color=LIGHT_PURPLE, point_base_color=STEEL_BLUE,
                         scene_bg_color=(1, 1, 1), focal_length=reconstructions['focal'])
        cam_view = renderer.render_rgba(reconstructions['verts'],
                                        reconstructions['objpcs'] + reconstructions['palm'] + reconstructions['objtrans'],
                                        cam_t=reconstructions['cam_t'], render_res=(224, 224), is_right=True, **misc_args)

        # Overlay image
        input_img = np.concatenate([input_img, np.ones_like(input_img[:, :, :1])], axis=2)  # Add alpha channel
        input_img_overlay = input_img[:, :, :3] * (1 - cam_view[:, :, 3:]) + cam_view[:, :, :3] * cam_view[:, :, 3:]

        return input_img_overlay, f'{num_dets} hands detected'
    else:
        return input_img, f'{num_dets} hands detected'


header = ('''