import numpy as np
import torch
import torchvision.transforms as transforms
from einops import rearrange
from PIL import Image

from custom_controlnet_aux.util import HWC3, common_input_validate, resize_image_with_pad, custom_hf_download, DIFFUSION_EDGE_MODEL_NAME
from .models.dsine_arch import DSINE
from custom_controlnet_aux.dsine.utils.utils import get_intrins_from_fov

# load model
def load_checkpoint(fpath, model):
    ckpt = torch.load(fpath, map_location='cpu')['model']

    load_dict = {}
    for k, v in ckpt.items():
        # Checkpoints saved through nn.DataParallel prefix every key with
        # 'module.'; strip the prefix so the keys match a bare model.
        if k.startswith('module.'):
            k_ = k.replace('module.', '')
            load_dict[k_] = v
        else:
            load_dict[k] = v

    model.load_state_dict(load_dict)
    return model
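
# Illustrative example of the stripping above: a key saved as
# 'module.encoder.conv1.weight' is loaded as 'encoder.conv1.weight'.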

def get_pad(orig_H, orig_W):
    """Return (left, right, top, bottom) padding that rounds each side up to a multiple of 64."""
    if orig_W % 64 == 0:
        l = 0
        r = 0
    else:
        new_W = 64 * ((orig_W // 64) + 1)
        l = (new_W - orig_W) // 2
        r = (new_W - orig_W) - l

    if orig_H % 64 == 0:
        t = 0
        b = 0
    else:
        new_H = 64 * ((orig_H // 64) + 1)
        t = (new_H - orig_H) // 2
        b = (new_H - orig_H) - t
    return l, r, t, b
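
# Worked example (illustrative): a 480x500 image is a multiple of 64 on neither
# side, so both are rounded up to 512 and the slack is split evenly:
#   get_pad(480, 500) == (6, 6, 16, 16)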

class DsineDetector:
    def __init__(self, model):
        self.model = model
        self.norm = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        self.device = "cpu"

    @classmethod
    def from_pretrained(cls, pretrained_model_or_path=DIFFUSION_EDGE_MODEL_NAME, filename="dsine.pt"):
        model_path = custom_hf_download(pretrained_model_or_path, filename)
        model = DSINE()
        model = load_checkpoint(model_path, model)
        model.eval()
        return cls(model)

    def to(self, device):
        self.model.to(device)
        # pixel_coords is a plain tensor attribute rather than a registered
        # buffer, so it has to be moved to the device explicitly.
        self.model.pixel_coords = self.model.pixel_coords.to(device)
        self.device = device
        return self

    def __call__(self, input_image, fov=60.0, iterations=5, detect_resolution=512, output_type="pil", upscale_method="INTER_CUBIC", **kwargs):
        self.model.num_iter = iterations
        input_image, output_type = common_input_validate(input_image, output_type, **kwargs)
        orig_H, orig_W = input_image.shape[:2]
        l, r, t, b = get_pad(orig_H, orig_W)
        input_image, remove_pad = resize_image_with_pad(input_image, detect_resolution, upscale_method, mode="constant")

        with torch.no_grad():
            input_image = torch.from_numpy(input_image).float().to(self.device)
            input_image = input_image / 255.0
            input_image = rearrange(input_image, 'h w c -> 1 c h w')
            input_image = self.norm(input_image)

            # Build pinhole intrinsics from the assumed field of view, then shift
            # the principal point by the left/top padding offsets.
            intrins = get_intrins_from_fov(new_fov=fov, H=orig_H, W=orig_W, device=self.device).unsqueeze(0)
            intrins[:, 0, 2] += l
            intrins[:, 1, 2] += t

            normal = self.model(input_image, intrins)
            # The model returns one prediction per refinement iteration; keep the last.
            normal = normal[-1][0]
            # Map normals from [-1, 1] into [0, 1] for visualization.
            normal = ((normal + 1) * 0.5).clip(0, 1)
            normal = rearrange(normal, 'c h w -> h w c').cpu().numpy()

        normal_image = (normal * 255.0).clip(0, 255).astype(np.uint8)
        # Crop away the padding added by resize_image_with_pad.
        detected_map = HWC3(remove_pad(normal_image))

        if output_type == "pil":
            detected_map = Image.fromarray(detected_map)
        return detected_map
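
# Minimal usage sketch (illustrative; "input.png" and "normal_map.png" are
# placeholder paths, and the default checkpoint is fetched via custom_hf_download):
if __name__ == "__main__":
    detector = DsineDetector.from_pretrained().to("cuda" if torch.cuda.is_available() else "cpu")
    image = np.array(Image.open("input.png").convert("RGB"))
    normal_map = detector(image, fov=60.0, iterations=5, detect_resolution=512)
    normal_map.save("normal_map.png")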