import os
import re
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import torch
from einops import repeat
from PIL import Image

from custom_mmpkg.custom_mmcv.utils import Config, DictAction
from custom_controlnet_aux.metric3d.mono.model.monodepth_model import get_configured_monodepth_model
from custom_controlnet_aux.metric3d.mono.utils.running import load_ckpt
from custom_controlnet_aux.metric3d.mono.utils.do_test import transform_test_data_scalecano, get_prediction
from custom_controlnet_aux.metric3d.mono.utils.visualization import vis_surface_normal
from custom_controlnet_aux.util import HWC3, common_input_validate, resize_image_with_pad, custom_hf_download, METRIC3D_MODEL_NAME

CODE_SPACE = Path(os.path.dirname(os.path.abspath(__file__)))


def load_model(model_selection, model_path):
    """Build a Metric3D monodepth model from the config matching the selected ViT backbone."""
    if model_selection == "vit-small":
        cfg = Config.fromfile(CODE_SPACE / 'mono/configs/HourglassDecoder/vit.raft5.small.py')
    elif model_selection == "vit-large":
        cfg = Config.fromfile(CODE_SPACE / 'mono/configs/HourglassDecoder/vit.raft5.large.py')
    elif model_selection == "vit-giant2":
        cfg = Config.fromfile(CODE_SPACE / 'mono/configs/HourglassDecoder/vit.raft5.giant2.py')
    else:
        raise NotImplementedError(f"metric3d model: {model_selection}")
    model = get_configured_monodepth_model(cfg)
    model, _, _, _ = load_ckpt(model_path, model, strict_match=False)
    model.eval()
    return model, cfg
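
# Hedged usage sketch for load_model; the checkpoint path below is hypothetical.
# In practice Metric3DDetector.from_pretrained downloads and resolves it for you:
#
#     model, cfg = load_model("vit-small", "/path/to/metric_depth_vit_small_800k.pth")
#     model = model.to("cuda")  # optional; inference defaults to CPU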


def gray_to_colormap(img, cmap='rainbow'):
    """Convert a single-channel float map to an RGB matplotlib colormap image."""
    assert img.ndim == 2
    img = img.copy()  # avoid mutating the caller's array
    img[img < 0] = 0
    mask_invalid = img < 1e-10
    img = img / (img.max() + 1e-8)
    norm = plt.Normalize(vmin=0, vmax=1.1)
    cmap_m = plt.get_cmap(cmap)
    mappable = plt.cm.ScalarMappable(norm=norm, cmap=cmap_m)
    colormap = (mappable.to_rgba(img)[:, :, :3] * 255).astype(np.uint8)
    colormap[mask_invalid] = 0  # zero/invalid pixels are rendered black
    return colormap
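
# Hedged usage sketch for gray_to_colormap (kept as a comment so importing this
# module stays side-effect free); the array here is synthetic, not model output:
#
#     fake_depth = np.abs(np.random.randn(240, 320)).astype(np.float32)
#     vis = gray_to_colormap(fake_depth)          # (240, 320, 3) uint8 RGB
#     Image.fromarray(vis).save("depth_vis.png")  # hypothetical output path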


def predict_depth_normal(model, cfg, np_img, fx=1000.0, fy=1000.0, state_cache=None):
    # Avoid a mutable default argument: each call gets its own cache unless one is passed in.
    if state_cache is None:
        state_cache = {}
    # Pinhole intrinsics [fx, fy, cx, cy]; the principal point is assumed to be the image center.
    intrinsic = [fx, fy, np_img.shape[1] / 2, np_img.shape[0] / 2]
    rgb_input, cam_models_stacks, pad, label_scale_factor = transform_test_data_scalecano(
        np_img, intrinsic, cfg.data_basic, device=next(model.parameters()).device
    )
    with torch.no_grad():
        pred_depth, confidence, output = get_prediction(
            model=model,
            input=rgb_input.unsqueeze(0),
            cam_model=cam_models_stacks,
            pad_info=pad,
            scale_info=label_scale_factor,
            gt_depth=None,
            normalize_scale=cfg.data_basic.depth_range[1],
            ori_shape=[np_img.shape[0], np_img.shape[1]],
        )

        # Crop the test-time padding away from both predictions.
        pred_normal = output['normal_out_list'][0][:, :3, :, :]
        H, W = pred_normal.shape[2:]
        pred_normal = pred_normal[:, :, pad[0]:H - pad[1], pad[2]:W - pad[3]]
        pred_depth = pred_depth[:, :, pad[0]:H - pad[1], pad[2]:W - pad[3]]

    pred_depth = pred_depth.squeeze().cpu().numpy()
    pred_color = gray_to_colormap(pred_depth, 'Greys')

    pred_normal = torch.nn.functional.interpolate(pred_normal, [np_img.shape[0], np_img.shape[1]], mode='bilinear').squeeze()
    pred_normal = pred_normal.permute(1, 2, 0)
    pred_color_normal = vis_surface_normal(pred_normal)
    pred_normal = pred_normal.cpu().numpy()

    # Store depth and normal maps in the state cache for potential 3D reconstruction.
    state_cache['depth'] = pred_depth
    state_cache['normal'] = pred_normal
    state_cache['img'] = np_img
    state_cache['intrinsic'] = intrinsic
    state_cache['confidence'] = confidence

    return pred_color, pred_color_normal, state_cache
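
# Hedged usage sketch for predict_depth_normal. Paths and focal lengths below are
# hypothetical; real fx/fy (in pixels) yield metrically meaningful depth:
#
#     model, cfg = load_model("vit-small", "/path/to/checkpoint.pth")
#     rgb = np.asarray(Image.open("photo.jpg").convert("RGB"))  # HWC uint8
#     depth_vis, normal_vis, cache = predict_depth_normal(model, cfg, rgb, fx=1400.0, fy=1400.0)
#     # cache['depth'] holds the raw depth array, cache['normal'] the HWC normal array.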


class Metric3DDetector:
    def __init__(self, model, cfg):
        self.model = model
        self.cfg = cfg
        self.device = "cpu"

    @classmethod
    def from_pretrained(cls, pretrained_model_or_path=METRIC3D_MODEL_NAME, filename="metric_depth_vit_small_800k.pth"):
        model_path = custom_hf_download(pretrained_model_or_path, filename)
        # Infer the ViT backbone size ("small", "large", "giant2") from the checkpoint filename.
        backbone = re.findall(r"metric_depth_vit_(\w+)_", model_path)[0]
        model, cfg = load_model(f'vit-{backbone}', model_path)
        return cls(model, cfg)

    def to(self, device):
        self.model.to(device)
        self.device = device
        return self

    def __call__(self, input_image, detect_resolution=512, fx=1000, fy=1000, output_type=None, upscale_method="INTER_CUBIC", depth_and_normal=True, **kwargs):
        input_image, output_type = common_input_validate(input_image, output_type, **kwargs)
        depth_map, normal_map, _ = predict_depth_normal(self.model, self.cfg, input_image, fx=fx, fy=fy)
        # Only the normal map is inverted here, to match ControlNet's expected normal convention.
        normal_map = 255 - normal_map
        depth_map, remove_pad = resize_image_with_pad(depth_map, detect_resolution, upscale_method)
        normal_map, _ = resize_image_with_pad(normal_map, detect_resolution, upscale_method)
        depth_map, normal_map = remove_pad(depth_map), remove_pad(normal_map)
        if output_type == "pil":
            depth_map = Image.fromarray(depth_map)
            normal_map = Image.fromarray(normal_map)
        if depth_and_normal:
            return depth_map, normal_map
        else:
            return depth_map
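

# Minimal end-to-end sketch, assuming this module's own defaults (METRIC3D_MODEL_NAME
# and the vit-small checkpoint) and a hypothetical local image "input.png". Guarded so
# it only runs when the file is executed directly, never on import.
if __name__ == "__main__":
    detector = Metric3DDetector.from_pretrained()  # downloads metric_depth_vit_small_800k.pth
    detector = detector.to("cuda" if torch.cuda.is_available() else "cpu")
    image = np.asarray(Image.open("input.png").convert("RGB"))  # hypothetical input path
    depth, normal = detector(image, detect_resolution=512, output_type="pil")
    depth.save("depth.png")
    normal.save("normal.png")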