- dust3r/cloud_opt_flow/optimizer.py +4 -3
 - requirements.txt +1 -1
 
    	
        dust3r/cloud_opt_flow/optimizer.py
    CHANGED
    
    | 
         @@ -11,9 +11,9 @@ from dust3r.utils.geometry import xy_grid, geotrf, depthmap_to_pts3d 
     | 
|
| 11 | 
         
             
            from dust3r.utils.device import to_cpu, to_numpy
         
     | 
| 12 | 
         
             
            from dust3r.utils.goem_opt import DepthBasedWarping, OccMask, WarpImage, depth_regularization_si_weighted, tum_to_pose_matrix
         
     | 
| 13 | 
         
             
            from third_party.raft import load_RAFT
         
     | 
| 14 | 
         
            -
            from sam2.build_sam import build_sam2_video_predictor
         
     | 
| 15 | 
         
            -
            sam2_checkpoint = "third_party/sam2/checkpoints/sam2.1_hiera_large.pt"
         
     | 
| 16 | 
         
            -
            model_cfg = "configs/sam2.1/sam2.1_hiera_l.yaml"
         
     | 
| 17 | 
         | 
| 18 | 
         
             
            def smooth_L1_loss_fn(estimate, gt, mask, beta=1.0, per_pixel_thre=50.):
         
     | 
| 19 | 
         
             
                loss_raw_shape = F.smooth_l1_loss(estimate*mask, gt*mask, beta=beta, reduction='none')
         
     | 
| 
         @@ -109,6 +109,7 @@ class PointCloudOptimizer(BasePCOptimizer): 
     | 
|
| 109 | 
         
             
                        self.flow_ji.requires_grad_(False)
         
     | 
| 110 | 
         
             
                        self.flow_valid_mask_i.requires_grad_(False)
         
     | 
| 111 | 
         
             
                        self.flow_valid_mask_j.requires_grad_(False)
         
     | 
| 
         | 
|
| 112 | 
         
             
                        if sam2_mask_refine: 
         
     | 
| 113 | 
         
             
                            with torch.no_grad():
         
     | 
| 114 | 
         
             
                                self.refine_motion_mask_w_sam2()
         
     | 
| 
         | 
|
| 11 | 
         
             
            from dust3r.utils.device import to_cpu, to_numpy
         
     | 
| 12 | 
         
             
            from dust3r.utils.goem_opt import DepthBasedWarping, OccMask, WarpImage, depth_regularization_si_weighted, tum_to_pose_matrix
         
     | 
| 13 | 
         
             
            from third_party.raft import load_RAFT
         
     | 
| 14 | 
         
            +
            # from sam2.build_sam import build_sam2_video_predictor
         
     | 
| 15 | 
         
            +
            # sam2_checkpoint = "third_party/sam2/checkpoints/sam2.1_hiera_large.pt"
         
     | 
| 16 | 
         
            +
            # model_cfg = "configs/sam2.1/sam2.1_hiera_l.yaml"
         
     | 
| 17 | 
         | 
| 18 | 
         
             
            def smooth_L1_loss_fn(estimate, gt, mask, beta=1.0, per_pixel_thre=50.):
         
     | 
| 19 | 
         
             
                loss_raw_shape = F.smooth_l1_loss(estimate*mask, gt*mask, beta=beta, reduction='none')
         
     | 
| 
         | 
|
| 109 | 
         
             
                        self.flow_ji.requires_grad_(False)
         
     | 
| 110 | 
         
             
                        self.flow_valid_mask_i.requires_grad_(False)
         
     | 
| 111 | 
         
             
                        self.flow_valid_mask_j.requires_grad_(False)
         
     | 
| 112 | 
         
            +
                        sam2_mask_refine = False
         
     | 
| 113 | 
         
             
                        if sam2_mask_refine: 
         
     | 
| 114 | 
         
             
                            with torch.no_grad():
         
     | 
| 115 | 
         
             
                                self.refine_motion_mask_w_sam2()
         
     | 
    	
        requirements.txt
    CHANGED
    
    | 
         @@ -19,4 +19,4 @@ seaborn 
     | 
|
| 19 | 
         
             
            evo
         
     | 
| 20 | 
         
             
            transformers
         
     | 
| 21 | 
         
             
            git+https://github.com/apple/ml-depth-pro.git
         
     | 
| 22 | 
         
            -
            git+https://github.com/facebookresearch/sam2.git
         
     | 
| 
         | 
|
| 19 | 
         
             
            evo
         
     | 
| 20 | 
         
             
            transformers
         
     | 
| 21 | 
         
             
            git+https://github.com/apple/ml-depth-pro.git
         
     | 
| 22 | 
         
            +
            # git+https://github.com/facebookresearch/sam2.git
         
     |