import os
import re
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import torch
from PIL import Image

from custom_mmpkg.custom_mmcv.utils import Config
from custom_controlnet_aux.metric3d.mono.model.monodepth_model import get_configured_monodepth_model
from custom_controlnet_aux.metric3d.mono.utils.running import load_ckpt
from custom_controlnet_aux.metric3d.mono.utils.do_test import transform_test_data_scalecano, get_prediction
from custom_controlnet_aux.metric3d.mono.utils.visualization import vis_surface_normal
from custom_controlnet_aux.util import common_input_validate, resize_image_with_pad, custom_hf_download, METRIC3D_MODEL_NAME

# Directory containing this file; used to resolve the bundled model configs.
CODE_SPACE = Path(os.path.dirname(os.path.abspath(__file__)))


def load_model(model_selection, model_path):
    """Build a Metric3D monodepth model from a bundled config and load its checkpoint."""
    if model_selection == "vit-small":
        cfg = Config.fromfile(CODE_SPACE / 'mono/configs/HourglassDecoder/vit.raft5.small.py')
    elif model_selection == "vit-large":
        cfg = Config.fromfile(CODE_SPACE / 'mono/configs/HourglassDecoder/vit.raft5.large.py')
    elif model_selection == "vit-giant2":
        cfg = Config.fromfile(CODE_SPACE / 'mono/configs/HourglassDecoder/vit.raft5.giant2.py')
    else:
        raise NotImplementedError(f"metric3d model: {model_selection}")
    model = get_configured_monodepth_model(cfg)
    model, _, _, _ = load_ckpt(model_path, model, strict_match=False)
    model.eval()
    return model, cfg


def gray_to_colormap(img, cmap='rainbow'):
    """Colorize a 2D grayscale map (e.g. depth) with a matplotlib colormap; returns uint8 RGB."""
    assert img.ndim == 2
    img = img.copy()  # avoid mutating the caller's array below
    img[img < 0] = 0
    mask_invalid = img < 1e-10
    img = img / (img.max() + 1e-8)
    norm = plt.Normalize(vmin=0, vmax=1.1)
    cmap_m = plt.get_cmap(cmap)
    mappable = plt.cm.ScalarMappable(norm=norm, cmap=cmap_m)
    colormap = (mappable.to_rgba(img)[:, :, :3] * 255).astype(np.uint8)
    colormap[mask_invalid] = 0  # invalid (near-zero) pixels are rendered black
    return colormap
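
# Illustrative usage sketch (values and output path are hypothetical, not part
# of the module): colorize a synthetic depth ramp.
#
#     depth = np.tile(np.linspace(0.0, 10.0, 64, dtype=np.float32), (64, 1))
#     color = gray_to_colormap(depth)            # (64, 64, 3) uint8, 'rainbow' colormap
#     Image.fromarray(color).save("ramp.png")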


def predict_depth_normal(model, cfg, np_img, fx=1000.0, fy=1000.0, state_cache=None):
    """Run Metric3D on an RGB image; returns colorized depth, colorized normals, and a state cache."""
    # Avoid a mutable default argument: a shared dict would leak state across calls.
    if state_cache is None:
        state_cache = {}
    # Pinhole intrinsics [fx, fy, cx, cy], with the principal point at the image center.
    intrinsic = [fx, fy, np_img.shape[1] / 2, np_img.shape[0] / 2]
    rgb_input, cam_models_stacks, pad, label_scale_factor = transform_test_data_scalecano(
        np_img, intrinsic, cfg.data_basic, device=next(model.parameters()).device
    )
    with torch.no_grad():
        pred_depth, confidence, output = get_prediction(
            model=model,
            input=rgb_input.unsqueeze(0),
            cam_model=cam_models_stacks,
            pad_info=pad,
            scale_info=label_scale_factor,
            gt_depth=None,
            normalize_scale=cfg.data_basic.depth_range[1],
            ori_shape=[np_img.shape[0], np_img.shape[1]],
        )
        pred_normal = output['normal_out_list'][0][:, :3, :, :]
        H, W = pred_normal.shape[2:]
        # Crop away the padding added by the test-time transform.
        pred_normal = pred_normal[:, :, pad[0]:H - pad[1], pad[2]:W - pad[3]]
    pred_depth = pred_depth[:, :, pad[0]:H - pad[1], pad[2]:W - pad[3]]
    pred_depth = pred_depth.squeeze().cpu().numpy()
    pred_color = gray_to_colormap(pred_depth, 'Greys')
    pred_normal = torch.nn.functional.interpolate(
        pred_normal, [np_img.shape[0], np_img.shape[1]], mode='bilinear'
    ).squeeze()
    pred_normal = pred_normal.permute(1, 2, 0)
    pred_color_normal = vis_surface_normal(pred_normal)
    pred_normal = pred_normal.cpu().numpy()
    # Store depth, normals, image, and intrinsics for potential 3D reconstruction.
    state_cache['depth'] = pred_depth
    state_cache['normal'] = pred_normal
    state_cache['img'] = np_img
    state_cache['intrinsic'] = intrinsic
    state_cache['confidence'] = confidence
    return pred_color, pred_color_normal, state_cache
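
# The cached depth and intrinsics above are enough to back-project every pixel
# into a metric 3D point cloud. The helper below is an illustrative sketch of
# standard pinhole back-projection (x = (u - cx) * z / fx, and likewise for y);
# it is not part of the upstream Metric3D API, and the name is our own.
def _depth_to_point_cloud(state_cache):
    """Back-project a cached depth map to an (H*W, 3) point cloud in camera space."""
    depth = state_cache['depth']
    fx, fy, cx, cy = state_cache['intrinsic']
    h, w = depth.shape
    # Pixel grid: u indexes columns (x), v indexes rows (y).
    u, v = np.meshgrid(np.arange(w), np.arange(h))
    x = (u - cx) * depth / fx
    y = (v - cy) * depth / fy
    return np.stack([x, y, depth], axis=-1).reshape(-1, 3)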


class Metric3DDetector:
    """Wraps a Metric3D model for ControlNet-style depth and normal-map preprocessing."""
    def __init__(self, model, cfg):
        self.model = model
        self.cfg = cfg
        self.device = "cpu"

    @classmethod
    def from_pretrained(cls, pretrained_model_or_path=METRIC3D_MODEL_NAME, filename="metric_depth_vit_small_800k.pth"):
        model_path = custom_hf_download(pretrained_model_or_path, filename)
        # Infer the ViT backbone size ("small", "large", "giant2") from the checkpoint filename.
        backbone = re.findall(r"metric_depth_vit_(\w+)_", model_path)[0]
        model, cfg = load_model(f'vit-{backbone}', model_path)
        return cls(model, cfg)

    def to(self, device):
        self.model.to(device)
        self.device = device
        return self

    def __call__(self, input_image, detect_resolution=512, fx=1000, fy=1000, output_type=None, upscale_method="INTER_CUBIC", depth_and_normal=True, **kwargs):
        input_image, output_type = common_input_validate(input_image, output_type, **kwargs)
        depth_map, normal_map, _ = predict_depth_normal(self.model, self.cfg, input_image, fx=fx, fy=fy)
        # Depth is already inverted via the 'Greys' colormap (near = bright);
        # the normals are flipped here to match ControlNet's convention.
        normal_map = 255 - normal_map
        depth_map, remove_pad = resize_image_with_pad(depth_map, detect_resolution, upscale_method)
        normal_map, _ = resize_image_with_pad(normal_map, detect_resolution, upscale_method)
        depth_map, normal_map = remove_pad(depth_map), remove_pad(normal_map)
        if output_type == "pil":
            depth_map = Image.fromarray(depth_map)
            normal_map = Image.fromarray(normal_map)
        if depth_and_normal:
            return depth_map, normal_map
        else:
            return depth_map
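

# Minimal usage sketch, assuming a local "example.jpg" exists and the default
# checkpoint is reachable via custom_hf_download; the file names and focal
# lengths here are illustrative, not part of the module.
if __name__ == "__main__":
    detector = Metric3DDetector.from_pretrained().to("cuda" if torch.cuda.is_available() else "cpu")
    image = np.array(Image.open("example.jpg").convert("RGB"))
    depth, normal = detector(image, fx=1000, fy=1000, output_type="pil")
    depth.save("depth.png")
    normal.save("normal.png")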