# Reference: https://github.com/YvanYin/Metric3D
import os
import sys
from typing import Literal

import click
import torch
import torch.nn.functional as F
import cv2

from moge.test.baseline import MGEBaselineInterface


class Baseline(MGEBaselineInterface):
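    """
    Metric3D baseline wrapped behind MGEBaselineInterface.

    Loads one of the ViT variants published on torch.hub ('yvanyin/metric3d') and predicts depth
    (metric when camera intrinsics are provided, scale-invariant otherwise) and surface normals
    for a single RGB image.
    """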
    def __init__(self, backbone: Literal['vits', 'vitl', 'vitg'], device):
        backbone_map = {
            'vits': 'metric3d_vit_small',
            'vitl': 'metric3d_vit_large',
            'vitg': 'metric3d_vit_giant2'
        }

        device = torch.device(device)
        model = torch.hub.load('yvanyin/metric3d', backbone_map[backbone], pretrain=True)
        model.to(device).eval()

        self.model = model
        self.device = device

    @click.command()
    @click.option('--backbone', type=click.Choice(['vits', 'vitl', 'vitg']), default='vitl', help='Encoder architecture.')
    @click.option('--device', type=str, default='cuda', help='Device to use.')
    @staticmethod
    def load(backbone: str = 'vitl', device: torch.device = 'cuda'):
        return Baseline(backbone, device)

    @torch.inference_mode()
    def inference_one_image(self, image: torch.Tensor, intrinsics: torch.Tensor = None):
        # Reference: https://github.com/YvanYin/Metric3D/blob/main/mono/utils/do_test.py

        # convert to an RGB HWC array in [0, 255] (the input image is CHW with values in [0, 1])
        rgb_origin = image.cpu().numpy().transpose((1, 2, 0)) * 255

        # keep ratio resize
        input_size = (616, 1064) # for vit model
        h, w = rgb_origin.shape[:2]
        scale = min(input_size[0] / h, input_size[1] / w)
        rgb = cv2.resize(rgb_origin, (int(w * scale), int(h * scale)), interpolation=cv2.INTER_LINEAR)
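        # Assuming normalized intrinsics (focal length divided by image size), convert fx to a
        # focal length in pixels at the resized resolution.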
        if intrinsics is not None:
            focal = intrinsics[0, 0] * int(w * scale)
        
        # padding to input_size
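        # Note: the padding value below equals the ImageNet mean used for normalization, so padded pixels become zero after normalization.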
        padding = [123.675, 116.28, 103.53]
        h, w = rgb.shape[:2]
        pad_h = input_size[0] - h
        pad_w = input_size[1] - w
        pad_h_half = pad_h // 2
        pad_w_half = pad_w // 2
        rgb = cv2.copyMakeBorder(rgb, pad_h_half, pad_h - pad_h_half, pad_w_half, pad_w - pad_w_half, cv2.BORDER_CONSTANT, value=padding)
        pad_info = [pad_h_half, pad_h - pad_h_half, pad_w_half, pad_w - pad_w_half]

        # normalize rgb
        mean = torch.tensor([123.675, 116.28, 103.53]).float()[:, None, None]
        std = torch.tensor([58.395, 57.12, 57.375]).float()[:, None, None]
        rgb = torch.from_numpy(rgb.transpose((2, 0, 1))).float()
        rgb = torch.div((rgb - mean), std)
        rgb = rgb[None, :, :, :].to(self.device)

        # inference 
        pred_depth, confidence, output_dict = self.model.inference({'input': rgb})

        # unpad
        pred_depth = pred_depth.squeeze()
        pred_depth = pred_depth[pad_info[0] : pred_depth.shape[0] - pad_info[1], pad_info[2] : pred_depth.shape[1] - pad_info[3]]
        pred_depth = pred_depth.clamp_min(0.5)  # clamp to a minimum of 0.5 m, since Metric3D can yield very small depth values that would crash the scale-shift alignment
        
        # upsample to original size
        pred_depth = F.interpolate(pred_depth[None, None, :, :], image.shape[-2:], mode='bilinear').squeeze()
        
        if intrinsics is not None:
            # de-canonical transform
            canonical_to_real_scale = focal / 1000.0 # 1000.0 is the focal length of canonical camera
            pred_depth = pred_depth * canonical_to_real_scale # now the depth is metric
            pred_depth = torch.clamp(pred_depth, 0, 300)

        pred_normal, normal_confidence = output_dict['prediction_normal'].split([3, 1], dim=1) # see https://arxiv.org/abs/2109.09881 for details

        # unpad the normal map
        pred_normal = pred_normal.squeeze(0)
        pred_normal = pred_normal[:, pad_info[0] : pred_normal.shape[1] - pad_info[1], pad_info[2] : pred_normal.shape[2] - pad_info[3]]

        # upsample to the original size and renormalize to unit length
        pred_normal = F.interpolate(pred_normal[None, :, :, :], image.shape[-2:], mode='bilinear').squeeze(0)
        pred_normal = F.normalize(pred_normal, p=2, dim=0)

        return pred_depth, pred_normal.permute(1, 2, 0)

    @torch.inference_mode()
    def infer(self, image: torch.Tensor, intrinsics: torch.Tensor = None):
        # image: (B, 3, H, W) or (3, H, W), RGB values in [0, 1]
        if image.ndim == 3:
            pred_depth, pred_normal = self.inference_one_image(image, intrinsics)
        else:
            pred_depth, pred_normal = [], []
            for i in range(image.shape[0]):
                pred_depth_i, pred_normal_i = self.inference_one_image(image[i], intrinsics[i] if intrinsics is not None else None)
                pred_depth.append(pred_depth_i)
                pred_normal.append(pred_normal_i)
            pred_depth = torch.stack(pred_depth, dim=0)
            pred_normal = torch.stack(pred_normal, dim=0)
        
        if intrinsics is not None:
            return {
                "depth_metric": pred_depth,
            }
        else:
            return {
                "depth_scale_invariant": pred_depth,
            }
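

# Minimal usage sketch, assuming the baseline is driven directly rather than through the MoGe
# evaluation CLI. The random 512x512 image and the normalized intrinsics below are made-up
# placeholders; 'vitl' and the device choice are arbitrary.
if __name__ == '__main__':
    baseline = Baseline(backbone='vitl', device='cuda' if torch.cuda.is_available() else 'cpu')

    image = torch.rand(3, 512, 512)  # RGB, channels-first, values in [0, 1]
    intrinsics = torch.tensor([
        [0.9, 0.0, 0.5],
        [0.0, 1.2, 0.5],
        [0.0, 0.0, 1.0],
    ])  # intrinsics normalized by image width/height (assumed convention)

    outputs = baseline.infer(image, intrinsics)
    print({k: tuple(v.shape) for k, v in outputs.items()})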