JasonSmithSO's picture
Upload 777 files
0034848 verified
import os
import types
import warnings
import cv2
import numpy as np
import torch
import torchvision.transforms as transforms
from einops import rearrange
from PIL import Image
from custom_controlnet_aux.util import HWC3, common_input_validate, resize_image_with_pad, custom_hf_download, DIFFUSION_EDGE_MODEL_NAME
from .models.dsine_arch import DSINE
from custom_controlnet_aux.dsine.utils.utils import get_intrins_from_fov
# load model
def load_checkpoint(fpath, model):
ckpt = torch.load(fpath, map_location='cpu')['model']
load_dict = {}
for k, v in ckpt.items():
if k.startswith('module.'):
k_ = k.replace('module.', '')
load_dict[k_] = v
else:
load_dict[k] = v
model.load_state_dict(load_dict)
return model
def get_pad(orig_H, orig_W):
if orig_W % 64 == 0:
l = 0
r = 0
else:
new_W = 64 * ((orig_W // 64) + 1)
l = (new_W - orig_W) // 2
r = (new_W - orig_W) - l
if orig_H % 64 == 0:
t = 0
b = 0
else:
new_H = 64 * ((orig_H // 64) + 1)
t = (new_H - orig_H) // 2
b = (new_H - orig_H) - t
return l, r, t, b
class DsineDetector:
def __init__(self, model):
self.model = model
self.norm = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
self.device = "cpu"
@classmethod
def from_pretrained(cls, pretrained_model_or_path=DIFFUSION_EDGE_MODEL_NAME, filename="dsine.pt"):
model_path = custom_hf_download(pretrained_model_or_path, filename)
model = DSINE()
model = load_checkpoint(model_path, model)
model.eval()
return cls(model)
def to(self, device):
self.model.to(device)
self.model.pixel_coords = self.model.pixel_coords.to(device)
self.device = device
return self
def __call__(self, input_image, fov=60.0, iterations=5, detect_resolution=512, output_type="pil", upscale_method="INTER_CUBIC", **kwargs):
self.model.num_iter = iterations
input_image, output_type = common_input_validate(input_image, output_type, **kwargs)
orig_H, orig_W = input_image.shape[:2]
l, r, t, b = get_pad(orig_H, orig_W)
input_image, remove_pad = resize_image_with_pad(input_image, detect_resolution, upscale_method, mode="constant")
with torch.no_grad():
input_image = torch.from_numpy(input_image).float().to(self.device)
input_image = input_image / 255.0
input_image = rearrange(input_image, 'h w c -> 1 c h w')
input_image = self.norm(input_image)
intrins = get_intrins_from_fov(new_fov=fov, H=orig_H, W=orig_W, device=self.device).unsqueeze(0)
intrins[:, 0, 2] += l
intrins[:, 1, 2] += t
normal = self.model(input_image, intrins)
normal = normal[-1][0]
normal = ((normal + 1) * 0.5).clip(0, 1)
normal = rearrange(normal, 'c h w -> h w c').cpu().numpy()
normal_image = (normal * 255.0).clip(0, 255).astype(np.uint8)
detected_map = HWC3(normal_image)
if output_type == "pil":
detected_map = Image.fromarray(detected_map)
return detected_map