import unittest

import torch
import torch.nn as nn
import torch.nn.functional as F

from sam2.modeling.sam2_base import SAM2Base
from sam2.build_sam import build_sam2_video_predictor
from sam2.utils.transforms import SAM2Transforms
from sam2.modeling.sam2_utils import get_1d_sine_pe


class SAM2VideoTrainer(nn.Module):
    """
    SAM2VideoTrainer is a PyTorch module for training a video segmentation model using SAM2.

    Attributes:
        device (torch.device): The device to run the model on.
        model (nn.Module): The SAM2 video predictor model.
        num_feature_levels (int): Number of feature levels in the model.
        memory_size (int): Size of the memory for storing features.
        _transforms (SAM2Transforms): Transformations applied to the input data.
        _bb_feat_sizes (list): Spatial dimensions for backbone feature maps.
        num_maskmem (int): Number of mask memory features.
        sam_point_coords (torch.Tensor): Placeholder for SAM point coordinates.
        sam_point_labels (torch.Tensor): Placeholder for SAM point labels.
        _orig_hw (list): Original height and width of the input frames.
        maskmem_features (list): List of mask memory features.
        maskmem_pos_enc (list): List of mask memory positional encodings.
        batch_size (int): Batch size of the input data.
        obj_ptrs (list): List of object pointers.
    """

    def __init__(self, model_cfg, sam2_checkpoint, device, memory_size=7, mask_threshold=0.5, use_mask_threshold=False):
        """
        Initializes the SAM2VideoTrainer class.

        Args:
            model_cfg (dict): Configuration dictionary for the model.
            sam2_checkpoint (str): Path to the SAM2 checkpoint file.
            device (torch.device): The device to run the model on (e.g., 'cpu' or 'cuda').
            memory_size (int, optional): Size of the memory. Defaults to 7.
            mask_threshold (float, optional): Threshold for mask prediction. Defaults to 0.5.
            use_mask_threshold (bool, optional): Flag to use mask thresholding. Defaults to False.

        Attributes:
            device (torch.device): The device to run the model on.
            model (SAM2VideoPredictor): The SAM2 video predictor model.
            num_feature_levels (int): Number of feature levels in the model.
            memory_size (int): Size of the memory.
            _transforms (SAM2Transforms): Transformations applied to the input data.
            _bb_feat_sizes (list): Spatial dimensions for backbone feature maps.
            num_maskmem (int): Number of mask memories.
            sam_point_coords (torch.Tensor): Tensor for SAM point coordinates.
            sam_point_labels (torch.Tensor): Tensor for SAM point labels.
            mask_threshold (float): Threshold for mask prediction.
            use_mask_threshold (bool): Flag to use mask thresholding.
        """
        super().__init__()
        self.device = device
        self.model = build_sam2_video_predictor(
            model_cfg, sam2_checkpoint, device=self.device, mode="train"
        )
        self.model.train()
        self.num_feature_levels = self.model.num_feature_levels

        self.num_maskmem = 7
        self.memory_size = (
            memory_size if memory_size <= self.num_maskmem else self.num_maskmem
        )

        self._transforms = SAM2Transforms(
            resolution=self.model.image_size,
            mask_threshold=0.5,
            max_hole_area=0,
            max_sprinkle_area=0,
        )

        self._bb_feat_sizes = [
            (256, 256),
            (128, 128),
            (64, 64),
        ]

        self.sam_point_coords = torch.zeros(1, 1, 2, device=device)
        self.sam_point_labels = -torch.ones(1, 1, dtype=torch.int32, device=device)
        self.mask_threshold = mask_threshold
        self.use_mask_threshold = use_mask_threshold

        self.init_state()

    def init_state(self):
        """
        Initializes the state variables for the video trainer.

        This method sets the initial state of various attributes used in the video
        training process. It resets the original height and width, mask memory
        features, mask memory positional encoding, batch size, frame counters, and
        object pointers to their default values.

        Attributes:
            _orig_hw (tuple or None): Original height and width of the video frames.
            maskmem_features (Any or None): Features related to mask memory.
            maskmem_pos_enc (Any or None): Positional encoding for mask memory.
            batch_size (int or None): Size of the batch for training.
            current_frame_idx (int): Index of the frame currently being processed.
            obj_ptrs (list): List of object pointers used in the training process.
            num_frames (int): Number of frames in the current video clip.
        """
        self._orig_hw = None
        self.maskmem_features = None
        self.maskmem_pos_enc = None
        self.batch_size = None
        self.current_frame_idx = 0
        self.obj_ptrs = []
        self.num_frames = 0

    def reset_state(self):
        """
        Resets the state of the video trainer.

        This method clears the internal state variables, setting them to their initial values:
        - `_orig_hw`: Set to None. Represents the original height and width.
        - `maskmem_features`: Set to None. Represents the mask memory features.
        - `maskmem_pos_enc`: Set to None. Represents the mask memory positional encoding.
        - `batch_size`: Set to None. Represents the batch size.
        - `current_frame_idx`: Set to 0. Represents the index of the current frame.
        - `obj_ptrs`: Set to an empty list. Represents the object pointers.
        - `num_frames`: Set to 0. Represents the number of frames in the clip.
        """
        self._orig_hw = None
        self.maskmem_features = None
        self.maskmem_pos_enc = None
        self.batch_size = None
        self.current_frame_idx = 0
        self.obj_ptrs = []
        self.num_frames = 0

    def forward(self, videos, bboxes, labels=None):
        """
        Forward pass for processing video frames and predicting masks, logits, and IoUs.

        Args:
            videos (torch.Tensor): A tensor of shape (batch_size, num_frames, C, H, W) representing the input video frames.
            bboxes (torch.Tensor): A tensor of shape (batch_size, 4) representing the bounding boxes for the first frame.
            labels (torch.Tensor, optional): A tensor of shape (batch_size, num_frames, 1, H, W) representing the ground truth masks for each frame. Defaults to None.

        Returns:
            tuple: A tuple containing:
            - all_masks (list of torch.Tensor): A list of tensors representing the predicted masks for each frame.
            - all_logits (list of torch.Tensor): A list of tensors representing the predicted logits for each frame.
            - all_ious (list of torch.Tensor): A list of tensors representing the predicted IoUs for each frame.
        """
        self.init_state()
        batch_size, num_frames, C, H, W = videos.shape
        self.num_frames = num_frames
        self._orig_hw = [H, W]
        self.batch_size = batch_size

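        # Run the image encoder once over all (batch, frame) pairs, then restore the
        # (batch_size, num_frames, ...) layout before splitting the features per frame.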
        videos = videos.view(batch_size * num_frames, C, H, W)
        features = self.model.forward_image(videos)
        features = {
            k: (
                v.view(batch_size, num_frames, *v.shape[1:])
                if not isinstance(v, list)
                else ([_v.view(batch_size, num_frames, *_v.shape[1:]) for _v in v])
            )
            for k, v in features.items()
        }
        frame_features = self.preprocess_frame_features(
            features, batch_size, num_frames
        )

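        # Prompt the first frame with its bounding box; no memory is available yet.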
        first_frame_features = frame_features[0]
        first_frame_bbox = bboxes.view(batch_size, 4)

        first_frame_masks, first_frame_logits, first_frame_ious, object_score_logits = (
            self._predict_first_frame(first_frame_features, first_frame_bbox)
        )

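        # Initialize the memory bank from the first frame, using either the predicted
        # mask or the ground-truth mask when labels are provided.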
        prev_pred_mask = first_frame_masks if labels is None else labels[:, 0]
        memory = self._initialize_memory(first_frame_features, prev_pred_mask, object_score_logits)

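        # Collect the first-frame outputs, then propagate through the remaining frames,
        # conditioning each prediction on the memory bank and updating it afterwards.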
        all_masks, all_logits, all_ious = (
            [first_frame_masks],
            [first_frame_logits],
            [first_frame_ious],
        )
        for t in range(1, num_frames):
            self.current_frame_idx = t
            frame_feature = frame_features[t]
            masks, logits, ious, object_score_logits = self._predict_frame(
                frame_feature, memory, prev_pred_mask
            )
            all_masks.append(masks)
            all_logits.append(logits)
            all_ious.append(ious)
            if t < num_frames - 1:
                prev_pred_mask = masks if labels is None else labels[:, t]
                memory = self._update_memory(frame_feature, prev_pred_mask, memory, object_score_logits)

        self.reset_state()
        return all_masks, all_logits, all_ious

    def normalize_bbox(self, bbox):
        """
        Normalize the given bounding box coordinates.

        This method transforms the bounding box coordinates to a normalized form
        based on the original height and width of the image.

        Args:
            bbox (list or ndarray): The bounding box coordinates to be normalized.

        Returns:
            list or ndarray: The normalized bounding box coordinates.
        """
        norm_bbox = self._transforms.transform_boxes(
            bbox, normalize=True, orig_hw=self._orig_hw
        )
        return norm_bbox

    def _get_points_placeholder(self, batch_size=None):
        """
        Generates a placeholder for point coordinates and labels.

        Args:
            batch_size (int, optional): The size of the batch. If not provided,
                defaults to the instance's batch_size attribute.

        Returns:
            tuple: A tuple containing:
            - torch.Tensor: Point coordinates expanded to shape (batch_size, 1, 2).
            - torch.Tensor: Point labels expanded to shape (batch_size, 1), filled with -1 (padding).
        """
        batch_size = self.batch_size if batch_size is None else batch_size
        points_placeholder = (
            self.sam_point_coords.expand(batch_size, -1, -1),
            self.sam_point_labels.expand(batch_size, -1),
        )
        return points_placeholder

    def unbind_frame_features(self, frame_features, num_frames):
        """
        Split batched frame features into a list of per-frame feature dictionaries.
        """
        keys = frame_features.keys()
        unbound_frame_features = []
        for frame_idx in range(num_frames):
            frame_feature = {}
            for k in keys:
                frame_feature[k] = (
                    frame_features[k][:, frame_idx]
                    if not isinstance(frame_features[k], list)
                    else [v[:, frame_idx] for v in frame_features[k]]
                )
            unbound_frame_features.append(frame_feature)
        return unbound_frame_features

    def preprocess_frame_features(self, frame_features, batch_size, num_frames):
        """
        Preprocess frame features.
        """
        frame_features = self.unbind_frame_features(frame_features, num_frames)
        preprocessed_frame_features = []
        for frame_idx, frame_feature in enumerate(frame_features):
            feature_maps = frame_feature["backbone_fpn"][-self.num_feature_levels :]
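            # Flatten each feature map to a (HW, B, C) token sequence; on the first frame,
            # and only if the model is configured for it, the "no memory" embedding is
            # added to the lowest-resolution features.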
            vision_feats = [x.flatten(2).permute(2, 0, 1) for x in feature_maps]
            if frame_idx == 0 and self.model.directly_add_no_mem_embed:
                vision_feats[-1] = vision_feats[-1] + self.model.no_mem_embed

            feats = [
                feat.permute(1, 2, 0).view(batch_size, -1, *feat_size)
                for feat, feat_size in zip(
                    vision_feats[::-1], self._bb_feat_sizes[::-1]
                )
            ][::-1]
            _features = {
                "image_embed": feats[-1],
                "high_res_feats": feats[:-1],
                "backbone_fpn": frame_feature["backbone_fpn"][
                    -self.num_feature_levels :
                ],
                "vision_pos_enc": frame_feature["vision_pos_enc"][
                    -self.num_feature_levels :
                ],
            }
            preprocessed_frame_features.append(_features)
        return preprocessed_frame_features

    def _embed_bbox(self, bbox):
        """
        Embed a bounding box prompt by encoding its corners as SAM points (labels 2 and 3).
        """
        bbox = self.normalize_bbox(bbox)
        box_coords = bbox.reshape(-1, 2, 2)
        box_labels = torch.tensor([[2, 3]], dtype=torch.int, device=bbox.device)
        box_labels = box_labels.repeat(bbox.size(0), 1)
        concat_points = (box_coords, box_labels)
        sparse_embeddings, dense_embeddings = self.model.sam_prompt_encoder(
            points=concat_points, boxes=None, masks=None
        )
        return sparse_embeddings, dense_embeddings

    def _predict_first_frame(self, features, bbox):
        """
        Predict masks and IoUs for the first frame.
        """
        sparse_embeddings, dense_embeddings = self._embed_bbox(bbox)

        low_res_masks, ious, sam_output_tokens, object_score_logits = (
            self.model.sam_mask_decoder(
                image_embeddings=features["image_embed"],
                image_pe=self.model.sam_prompt_encoder.get_dense_pe(),
                sparse_prompt_embeddings=sparse_embeddings,
                dense_prompt_embeddings=dense_embeddings,
                multimask_output=False,
                repeat_image=False,
                high_res_features=features["high_res_feats"],
            )
        )

        sam_output_token = sam_output_tokens[:, -1]
        obj_ptr = self.model.obj_ptr_proj(sam_output_token)
        self.obj_ptrs.append(obj_ptr)
        pred_mask, pred_logit = self._postprocess_masks(low_res_masks)
        return pred_mask, pred_logit, ious[:, -1], object_score_logits

    def _postprocess_masks(self, logits, size=None):
        """
        Perform post-processing on output masks.
        """
        size = self._orig_hw if size is None else size
        logits = F.interpolate(logits, size, mode="bilinear", align_corners=False)
        logits = logits[:, -1].unsqueeze(1)
        masks = torch.sigmoid(logits)
        if self.use_mask_threshold:
            masks = (masks > self.mask_threshold).float()
        return masks, logits

    def _extract_memory_features(self, features, masks, object_score_logits):
        """
        Extracts memory features from the given features and masks.

        Args:
            features (dict): A dictionary containing feature maps from the backbone FPN.
            masks (Tensor): A tensor representing the masks to be used by the memory encoder.
            object_score_logits (Tensor): Logits indicating whether the object is present in the frame.

        Returns:
            dict: A dictionary containing:
            - "vision_features" (Tensor): The vision features extracted and processed by the memory encoder.
            - "vision_pos_enc" (Tensor): The positional encoding of the vision features.
        """
        pix_feat = features["backbone_fpn"][-1]
        maskmem_out = self.model.memory_encoder(
            pix_feat, masks, skip_mask_sigmoid=True
        )
        maskmem_features = maskmem_out["vision_features"]
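        # If the model carries a learned "no object" embedding, blend it into the memory
        # features for frames where the object-score logits indicate the object is absent.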
        if self.model.no_obj_embed_spatial is not None:
            is_obj_appearing = (object_score_logits > 0).float()
            maskmem_features += (
                1 - is_obj_appearing[..., None, None]
            ) * self.model.no_obj_embed_spatial[..., None, None].expand(
                *maskmem_features.shape
            )
        maskmem_features = maskmem_features.flatten(2).permute(2, 0, 1)
        maskmem_pos_enc = maskmem_out["vision_pos_enc"][-1].flatten(2).permute(2, 0, 1)
        return {"vision_features": maskmem_features, "vision_pos_enc": maskmem_pos_enc}

    def _initialize_memory(self, features, masks, object_score_logits):
        """
        Initialize memory for the first frame.
        """
        maskmem_out = self._extract_memory_features(features, masks, object_score_logits)
        self.maskmem_features = [maskmem_out["vision_features"]]
        self.maskmem_pos_enc = [maskmem_out["vision_pos_enc"]]
        return self.maskmem_features, self.maskmem_pos_enc

    def _update_memory(self, features, masks, memory=None, object_score_logits=None):
        """
        Update memory with new frame data.
        """
        if memory is None:
            maskmem_features, maskmem_pos_enc = (
                self.maskmem_features,
                self.maskmem_pos_enc,
            )
        else:
            maskmem_features, maskmem_pos_enc = memory

        maskmem_out = self._extract_memory_features(features, masks, object_score_logits)
        maskmem_features.append(maskmem_out["vision_features"])
        maskmem_pos_enc.append(maskmem_out["vision_pos_enc"])
        if len(maskmem_features) > self.memory_size:
            # Keep only the most recent `memory_size` entries so the memory bank acts as a
            # sliding window; truncate both the stored and the returned lists.
            maskmem_features = maskmem_features[-self.memory_size :]
            maskmem_pos_enc = maskmem_pos_enc[-self.memory_size :]
        self.maskmem_features = maskmem_features
        self.maskmem_pos_enc = maskmem_pos_enc
        return maskmem_features, maskmem_pos_enc

    def _prepare_memory(self, memory):
        """
        Prepare memory for the current frame.
        """
        if memory is None:
            maskmem_features, maskmem_pos_enc = (
                self.maskmem_features,
                self.maskmem_pos_enc,
            )
        else:
            maskmem_features, maskmem_pos_enc = memory
        # Add the temporal positional encoding to a new list rather than in place, so the
        # stored memory entries do not accumulate the encoding across frames.
        num_mems = len(maskmem_pos_enc)
        maskmem_pos_enc = [
            pos_enc + self.model.maskmem_tpos_enc[num_mems - idx - 1]
            for idx, pos_enc in enumerate(maskmem_pos_enc)
        ]
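        # Object-pointer tokens stored so far are appended after the spatial memory tokens;
        # they optionally receive a temporal position encoding and are split into
        # mem_dim-sized chunks when mem_dim is smaller than the hidden dimension.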
        obj_ptrs = torch.stack(self.obj_ptrs, dim=0)

        if self.model.add_tpos_enc_to_obj_ptrs:
            max_obj_ptrs_in_encoder = self.num_frames
            pos_list = [self.current_frame_idx]
            t_diff_max = max_obj_ptrs_in_encoder - 1
            tpos_dim = self.model.hidden_dim if self.model.proj_tpos_enc_in_obj_ptrs else self.model.mem_dim
            obj_pos = torch.tensor(pos_list, device=obj_ptrs.device)
            obj_pos = get_1d_sine_pe(obj_pos / t_diff_max, dim=tpos_dim)
            obj_pos = self.model.obj_ptr_tpos_proj(obj_pos)
            obj_pos = obj_pos.unsqueeze(1).expand(-1, self.batch_size, self.model.mem_dim)
        else:
            obj_pos = obj_ptrs.new_zeros(
                len(self.obj_ptrs), self.batch_size, self.model.mem_dim
            )
        C = self.model.hidden_dim
        if self.model.mem_dim < C:
            obj_ptrs = obj_ptrs.reshape(
                -1, self.batch_size, C // self.model.mem_dim, self.model.mem_dim
            )
            obj_ptrs = obj_ptrs.permute(0, 2, 1, 3).flatten(0, 1)
            obj_pos = obj_pos.repeat_interleave(C // self.model.mem_dim, dim=0)
        num_obj_ptr_tokens = obj_ptrs.shape[0]
        memory = torch.cat(maskmem_features + [obj_ptrs], dim=0)
        memory_pos_embed = torch.cat(maskmem_pos_enc + [obj_pos], dim=0)
        return memory, memory_pos_embed, num_obj_ptr_tokens

    def _predict_frame(self, features, memory, prev_mask=None):
        """
        Predict masks and IoUs for subsequent frames using memory.
        """
        memory, memory_pos_embed, num_obj_ptr_tokens = self._prepare_memory(memory)

        current_vision_feats = [
            x.flatten(2).permute(2, 0, 1) for x in features["backbone_fpn"]
        ]
        current_vision_pos_embeds = [
            x.flatten(2).permute(2, 0, 1) for x in features["vision_pos_enc"]
        ]
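        # Fuse the current frame's lowest-resolution features with the memory bank via
        # memory attention, then decode masks using an empty (placeholder) point prompt.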
        pix_feat_with_mem = self.model.memory_attention(
            curr=current_vision_feats[-1:],
            curr_pos=current_vision_pos_embeds[-1:],
            memory=memory,
            memory_pos=memory_pos_embed,
            num_obj_ptr_tokens=num_obj_ptr_tokens,
        )
        pix_feat_with_mem = pix_feat_with_mem.permute(1, 2, 0).view(
            *features["backbone_fpn"][-1].shape
        )
        sparse_embeddings, dense_embeddings = self.model.sam_prompt_encoder(
            points=self._get_points_placeholder(),
            boxes=None,
            masks=None,
        )
        low_res_masks, ious, _, object_score_logits = self.model.sam_mask_decoder(
            image_embeddings=pix_feat_with_mem,
            image_pe=self.model.sam_prompt_encoder.get_dense_pe(),
            sparse_prompt_embeddings=sparse_embeddings,
            dense_prompt_embeddings=dense_embeddings,
            multimask_output=False,
            repeat_image=False,
            high_res_features=features["high_res_feats"],
        )

        pred_mask, pred_logit = self._postprocess_masks(low_res_masks)
        return pred_mask, pred_logit, ious[:, -1], object_score_logits


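# Illustrative usage: a minimal training-loop sketch, assuming a dataloader that yields
# (videos, bboxes, labels) with labels shaped (B, T, 1, H, W). The optimizer, learning
# rate, and per-frame BCE loss below are placeholder choices, not values prescribed by
# this module.
#
#   trainer = SAM2VideoTrainer("sam2_hiera_t.yaml", "./checkpoints/sam2_hiera_tiny.pt", device)
#   optimizer = torch.optim.AdamW(trainer.parameters(), lr=1e-5)
#   for videos, bboxes, labels in dataloader:
#       _, logits, _ = trainer(videos, bboxes, labels)
#       loss = sum(
#           F.binary_cross_entropy_with_logits(logit, labels[:, t])
#           for t, logit in enumerate(logits)
#       ) / len(logits)
#       optimizer.zero_grad()
#       loss.backward()
#       optimizer.step()

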
class TestSAM2VideoTrainer(unittest.TestCase):
    def setUp(self):
        sam2_checkpoint = "./checkpoints/sam2_hiera_tiny.pt"
        model_cfg = "sam2_hiera_t.yaml"
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        self.trainer = SAM2VideoTrainer(model_cfg, sam2_checkpoint, device=self.device)

        self.batch_size = 2
        self.num_frames = 2
        self.C = 3
        self.H = 1024
        self.W = 1024

        self.videos = torch.randn(
            self.batch_size, self.num_frames, self.C, self.H, self.W
        ).to(self.device)
        self.masks = torch.zeros(
            self.batch_size, self.num_frames, 1, self.H, self.W
        ).to(self.device)
        self.bboxes = torch.tensor(
            [[100, 100, 200, 200], [150, 150, 250, 250]], dtype=torch.float32
        ).to(self.device)

    def test_forward(self):
        # The trainer returns per-frame lists of masks, logits, and IoU predictions.
        masks, logits, ious = self.trainer(self.videos, self.bboxes, None)

        print("Masks shape:", masks[0].shape)
        print("IoUs shape:", ious[0].shape)

        print("Masks:", masks)
        print("IoUs:", ious)


if __name__ == "__main__":
    unittest.main()