# Reference https://github.com/DepthAnything/Depth-Anything-V2/metric_depth
import os
import sys
from typing import Dict, Optional
from pathlib import Path

import click
import torch
import torch.nn.functional as F
import torchvision.transforms as T
import torchvision.transforms.functional as TF

from moge.test.baseline import MGEBaselineInterface


class Baseline(MGEBaselineInterface):
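    """Baseline wrapper exposing Depth-Anything-V2 (metric-depth variant) through the MGEBaselineInterface."""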

    def __init__(self, repo_path: str, backbone: str, domain: str, num_tokens: Optional[int], device: str):
        device = torch.device(device)
        repo_path = os.path.abspath(repo_path)
        if not Path(repo_path).exists():
            raise FileNotFoundError(f'Cannot find the Depth-Anything repository at {repo_path}. Please clone the repository and provide the path to it using the --repo option.')
        sys.path.append(os.path.join(repo_path, 'metric_depth'))
        from depth_anything_v2.dpt import DepthAnythingV2

        model_configs = {
            'vits': {'encoder': 'vits', 'features': 64, 'out_channels': [48, 96, 192, 384]},
            'vitb': {'encoder': 'vitb', 'features': 128, 'out_channels': [96, 192, 384, 768]},
            'vitl': {'encoder': 'vitl', 'features': 256, 'out_channels': [256, 512, 1024, 1024]}
        }

        # Depth-Anything-V2 provides separate metric checkpoints per domain: the indoor model is
        # fine-tuned on Hypersim (max depth 20 m), the outdoor model on Virtual KITTI (max depth 80 m).
        if domain == 'indoor':
            dataset = 'hypersim'
            max_depth = 20
        elif domain == 'outdoor':
            dataset = 'vkitti'
            max_depth = 80
        else:
            raise ValueError(f"Invalid domain: {domain}")

        model = DepthAnythingV2(**model_configs[backbone], max_depth=max_depth)
        checkpoint_path = os.path.join(repo_path, f'checkpoints/depth_anything_v2_metric_{dataset}_{backbone}.pth')
        if not os.path.exists(checkpoint_path):
            raise FileNotFoundError(f'Cannot find the checkpoint file at {checkpoint_path}. Please download the checkpoint file and place it in the checkpoints directory.')
        model.load_state_dict(torch.load(checkpoint_path, map_location='cpu', weights_only=True))
        model.eval().to(device)
                
        self.model = model
        self.num_tokens = num_tokens
        self.device = device

    @click.command()
    @click.option('--repo', 'repo_path', type=click.Path(), default='../Depth-Anything-V2', help='Path to the Depth-Anything repository.')
    @click.option('--backbone', type=click.Choice(['vits', 'vitb', 'vitl']), default='vitl', help='Backbone architecture.')
    @click.option('--domain', type=click.Choice(['indoor', 'outdoor']), help='Domain of the dataset.')
    @click.option('--num_tokens', type=int, default=None, help='Approximate number of ViT tokens to use at inference. If omitted, the default resizing scheme (shorter side to 518 px) is used.')
    @click.option('--device', type=str, default='cuda', help='Device to use for inference.')
    @staticmethod
    def load(repo_path: str, backbone: str, domain: str, num_tokens: Optional[int], device: str):
        return Baseline(repo_path, backbone, domain, num_tokens, device)

    @torch.inference_mode()
    def infer(self, image: torch.Tensor, intrinsics: Optional[torch.Tensor] = None) -> Dict[str, torch.Tensor]:
        original_height, original_width = image.shape[-2:]

        assert intrinsics is None, "Depth-Anything-V2 does not support camera intrinsics input"

        if image.ndim == 3:
            image = image.unsqueeze(0)
            omit_batch_dim = True
        else:
            omit_batch_dim = False

        if self.num_tokens is None:
            # Default Depth-Anything-V2 preprocessing: scale the shorter side to 518 px,
            # then round both sides to multiples of the ViT patch size (14).
            resize_factor = 518 / min(original_height, original_width)
            expected_width = round(original_width * resize_factor / 14) * 14
            expected_height = round(original_height * resize_factor / 14) * 14
        else:
            # Choose a patch grid whose token count is close to the requested budget while
            # preserving the input aspect ratio (cols/rows should match width/height).
            aspect_ratio = original_width / original_height
            tokens_cols = round((self.num_tokens * aspect_ratio) ** 0.5)
            tokens_rows = round((self.num_tokens / aspect_ratio) ** 0.5)
            expected_width = tokens_cols * 14
            expected_height = tokens_rows * 14
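        # Illustrative example (assumed values, not from the original file): a 640x480 input with
        # num_tokens=1000 gives tokens_cols = 37, tokens_rows = 27, i.e. a 518x378 resize (~999 tokens).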
        image = TF.resize(image, (expected_height, expected_width), interpolation=T.InterpolationMode.BICUBIC, antialias=True)
    
        # Standard ImageNet normalization; the input image is assumed to be RGB in [0, 1].
        image = TF.normalize(image, mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])

        # The metric head predicts per-pixel depth (in meters) at the resized resolution.
        depth = self.model(image)

        # Resize the prediction back to the original resolution.
        depth = F.interpolate(depth[:, None], size=(original_height, original_width), mode='bilinear', align_corners=False, antialias=False)[:, 0]

        if omit_batch_dim:
            depth = depth.squeeze(0)

        return {
            'depth_metric': depth
        }
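

# Minimal usage sketch (illustrative only, not part of the original interface; the repository path
# and the random input below are assumptions): construct the baseline directly and run single-image
# inference.
#
#   baseline = Baseline(repo_path='../Depth-Anything-V2', backbone='vitl',
#                       domain='indoor', num_tokens=None, device='cuda')
#   image = torch.rand(3, 480, 640, device='cuda')   # placeholder RGB image in [0, 1]
#   pred = baseline.infer(image)
#   pred['depth_metric'].shape                        # torch.Size([480, 640])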