File size: 6,269 Bytes
57746f1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
import torch
import torch.nn as nn
import yaml
import sys
sys.path.append(".")
sys.path.append("submodules")
sys.path.append("submodules/mast3r")
from mast3r.model import AsymmetricMASt3R
from src.ptv3 import PTV3
from src.gaussian_head import GaussianHead
from src.utils.points_process import merge_points
from src.losses import GaussianLoss
from src.lseg import LSegFeatureExtractor
import argparse

class LSM_MASt3R(nn.Module):
    """MASt3R + PointTransformerV3 + Gaussian head, conditioned on LSeg features.

    MASt3R and the LSeg feature extractor are loaded from pretrained weights and
    frozen (``requires_grad=False``, kept in eval mode). Only the remaining
    submodules (``point_transformer``, ``gaussian_head``, ``feature_reduction``,
    ``feature_expansion``) are written by :meth:`state_dict` and restored by
    :meth:`load_state_dict`.
    """

    # Attribute names of the pretrained, never-trained submodules. Their weights
    # are excluded from state_dict() and ignored by load_state_dict().
    _FROZEN_MODULES = ('mast3r', 'lseg_feature_extractor')

    def __init__(self,
                 mast3r_config,
                 point_transformer_config,
                 gaussian_head_config,
                 lseg_config,
                 ):
        """Build all submodules.

        Args:
            mast3r_config: kwargs forwarded to ``AsymmetricMASt3R.from_pretrained``.
            point_transformer_config: kwargs forwarded to ``PTV3``.
            gaussian_head_config: kwargs forwarded to ``GaussianHead``; its
                ``'d_gs_feats'`` entry (default 32) also sets the channel width
                of ``feature_reduction`` / ``feature_expansion``.
            lseg_config: kwargs forwarded to
                ``LSegFeatureExtractor.from_pretrained``.
        """
        super().__init__()
        # Raw constructor arguments, kept on the instance.
        self.config = {
            'mast3r_config': mast3r_config,
            'point_transformer_config': point_transformer_config,
            'gaussian_head_config': gaussian_head_config,
            'lseg_config': lseg_config
        }

        # Pretrained MASt3R, frozen.
        self.mast3r = AsymmetricMASt3R.from_pretrained(**mast3r_config)
        for param in self.mast3r.parameters():
            param.requires_grad = False
        self.mast3r.eval()

        # Trainable point transformer and Gaussian head.
        self.point_transformer = PTV3(**point_transformer_config)
        self.gaussian_head = GaussianHead(**gaussian_head_config)

        # Pretrained LSeg feature extractor, frozen.
        self.lseg_feature_extractor = LSegFeatureExtractor.from_pretrained(**lseg_config)
        for param in self.lseg_feature_extractor.parameters():
            param.requires_grad = False
        self.lseg_feature_extractor.eval()

        # 1x1 convolutions mapping between LSeg's 512 channels and d_gs_feats.
        d_gs_feats = gaussian_head_config.get('d_gs_feats', 32)
        self.feature_reduction = nn.Sequential(
            nn.Conv2d(512, d_gs_feats, kernel_size=1),
            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
        ) # (b, 512, h//2, w//2) -> (b, d_features, h, w)

        # NOTE(review): not used by forward() in this class; presumably consumed
        # elsewhere (e.g. by the loss, which receives the model) -- confirm.
        self.feature_expansion = nn.Sequential(
            nn.Conv2d(d_gs_feats, 512, kernel_size=1),
            nn.Upsample(scale_factor=0.5, mode='bilinear', align_corners=True)
        ) # (b, d_features, h, w) -> (b, 512, h//2, w//2)

    def forward(self, view1, view2):
        """Run the full pipeline on a pair of views.

        Args:
            view1, view2: per-view dicts passed to MASt3R and ``merge_points``;
                their ``'img'`` entries are also fed to the LSeg extractor.

        Returns:
            Whatever ``GaussianHead`` returns for the point-transformer output
            and the reduced LSeg features.
        """
        # Frozen MASt3R forward pass.
        mast3r_output = self.mast3r(view1, view2)

        # Merge the predicted points of both views into one point set.
        data_dict = merge_points(mast3r_output, view1, view2)

        # PointTransformerV3 forward pass.
        point_transformer_output = self.point_transformer(data_dict)

        # Frozen LSeg features, reduced to d_gs_feats channels.
        lseg_features = self.extract_lseg_features(view1, view2)

        # Gaussian head forward pass.
        final_output = self.gaussian_head(point_transformer_output, lseg_features)

        return final_output

    def extract_lseg_features(self, view1, view2):
        """Extract LSeg features for both views and project them to d_gs_feats.

        The images are concatenated along dim 0 with all of ``view1`` first,
        followed by all of ``view2``.
        """
        img = torch.cat([view1['img'], view2['img']], dim=0) # (v*b, 3, h, w)
        lseg_features = self.lseg_feature_extractor.extract_features(img) # (v*b, 512, h//2, w//2)
        lseg_features = self.feature_reduction(lseg_features) # (v*b, d_features, h, w)

        return lseg_features

    @staticmethod
    def from_pretrained(checkpoint_path, device='cuda'):
        """Rebuild a model from a training checkpoint and move it to ``device``.

        The checkpoint must contain ``'args'`` (an object with a ``model``
        attribute holding a constructor expression) and ``'model'`` (the
        trainable weights).
        """
        # NOTE(review): torch.load unpickles arbitrary objects and eval() below
        # executes a string read from the checkpoint -- only load checkpoints
        # from trusted sources. On torch >= 2.6, torch.load defaults to
        # weights_only=True, which may reject the pickled ``args`` object;
        # TODO confirm and pass weights_only=False explicitly if needed.
        ckpt = torch.load(checkpoint_path, map_location='cpu')

        config = ckpt['args']

        # config.model is presumably an expression such as "LSM_MASt3R(...)";
        # it is evaluated with this module's globals, so names defined here resolve.
        model = eval(config.model)

        model.load_state_dict(ckpt['model'])

        model = model.to(device)

        return model

    @classmethod
    def _frozen_key_prefixes(cls, prefix=''):
        """Return the state-dict key prefixes of the frozen submodules under ``prefix``."""
        return tuple(f'{prefix}{name}.' for name in cls._FROZEN_MODULES)

    def state_dict(self, destination=None, prefix='', keep_vars=False):
        """Return the state dict without the frozen submodules' entries.

        The frozen keys are removed relative to ``prefix``. When ``destination``
        is supplied (as ``nn.Module`` does when this model is nested inside a
        parent module), they are removed from ``destination`` as well, so a
        parent's ``state_dict()`` does not carry the frozen weights either.
        """
        full_state_dict = super().state_dict(destination=destination, prefix=prefix, keep_vars=keep_vars)

        frozen = self._frozen_key_prefixes(prefix)
        # Delete in place: this preserves the OrderedDict type and its
        # `_metadata`, and also cleans a caller-supplied `destination`.
        for key in [k for k in full_state_dict if k.startswith(frozen)]:
            del full_state_dict[key]

        return full_state_dict

    def load_state_dict(self, state_dict, strict=True):
        """Load the trainable weights, ignoring entries for the frozen submodules.

        Keys under ``mast3r.`` / ``lseg_feature_extractor.`` in ``state_dict`` are
        skipped, so checkpoints written by :meth:`state_dict` (which omit them)
        load cleanly, and so do checkpoints that happen to include them.

        Args:
            state_dict: mapping of parameter/buffer names to tensors.
            strict: if True, raise when a trainable key is missing from
                ``state_dict`` or ``state_dict`` has keys this model lacks.
                Frozen-submodule keys are never counted either way.

        Returns:
            The ``(missing_keys, unexpected_keys)`` named tuple produced by
            ``nn.Module.load_state_dict``, with frozen-submodule keys removed
            from ``missing_keys``.

        Raises:
            RuntimeError: if ``strict`` and keys are missing or unexpected, or
                (as in ``nn.Module``) on a tensor shape mismatch.
        """
        from collections import OrderedDict

        frozen = self._frozen_key_prefixes()
        trainable = OrderedDict(
            (k, v) for k, v in state_dict.items() if not k.startswith(frozen)
        )
        # Keep per-module version metadata so version-aware loaders can use it.
        metadata = getattr(state_dict, '_metadata', None)
        if metadata is not None:
            trainable._metadata = metadata

        # strict=False here because the frozen submodules are intentionally
        # absent from `trainable`; strictness is re-applied below to the
        # trainable keys only.
        result = super().load_state_dict(trainable, strict=False)
        missing = [k for k in result.missing_keys if not k.startswith(frozen)]
        result = result._replace(missing_keys=missing)

        if strict and (missing or result.unexpected_keys):
            problems = []
            if missing:
                problems.append('Missing key(s) in state_dict: {}.'.format(
                    ', '.join(f'"{k}"' for k in missing)))
            if result.unexpected_keys:
                problems.append('Unexpected key(s) in state_dict: {}.'.format(
                    ', '.join(f'"{k}"' for k in result.unexpected_keys)))
            raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
                type(self).__name__, '\n\t'.join(problems)))

        return result

    def train(self, mode=True):
        """Set train/eval mode while keeping the frozen submodules in eval mode.

        ``nn.Module.train`` recurses into every child, which would otherwise
        switch the frozen MASt3R / LSeg modules back into training mode.
        """
        super().train(mode)
        for name in self._FROZEN_MODULES:
            module = self._modules.get(name)
            if module is not None:
                module.eval()
        return self

if __name__ == "__main__":
    # Smoke test: build (or load) the model, run one Scannet batch through it,
    # and print the Gaussian loss. Relies on the module-level `argparse` import.
    from torch.utils.data import DataLoader

    parser = argparse.ArgumentParser()
    parser.add_argument('--checkpoint', type=str,
                        help='optional checkpoint loaded via LSM_MASt3R.from_pretrained')
    parser.add_argument('--config', type=str, default="configs/model_config.yaml",
                        help='model config YAML, used only when --checkpoint is not given')
    parser.add_argument('--device', type=str, default='cuda')
    args = parser.parse_args()

    # Initialize model. The YAML config is only read when building from scratch,
    # so a checkpoint-only run does not require the config file to exist.
    if args.checkpoint is not None:
        model = LSM_MASt3R.from_pretrained(args.checkpoint, device=args.device)
    else:
        with open(args.config, "r", encoding="utf-8") as f:
            config = yaml.safe_load(f)
        model = LSM_MASt3R(**config).to(args.device)

    model.eval()
    print(model)

    # Load dataset
    from src.datasets.scannet import Scannet
    dataset = Scannet(split='train', ROOT="data/scannet_processed", resolution=[(512, 384)])
    print(dataset)

    # Fetch one batch of views.
    data_loader = DataLoader(dataset, batch_size=3, shuffle=True)
    data = next(iter(data_loader))
    # Move the tensors used by this smoke test to the target device; other
    # keys stay wherever the loader put them.
    for view in data:
        for key in ('img', 'depthmap', 'camera_pose', 'camera_intrinsics'):
            view[key] = view[key].to(args.device)

    # Forward pass on the first two views.
    output = model(*data[:2])

    # Loss
    loss = GaussianLoss()
    loss_value = loss(*data, *output, model)
    print(loss_value)