File size: 5,763 Bytes
57abc33
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
from models import MaskDecoderHQ
from ppc_decoder import sam_decoder_reg
from segment_anything import sam_model_registry
import torch.nn as nn
import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt
from utils.transforms import ResizeLongestSide
from typing import List

trans = ResizeLongestSide(target_length=1024)

def save_prob_visualization(prob, filename="prob_visualization.png"):
    """
    可视化 1xwxh 的概率图并使用 plt.imshow 保存到本地
    :param prob: 形状为 1xwxh 的 tensor
    :param filename: 保存的文件名,默认为 'prob_visualization.png'
    """
    # 将 prob 转换为 numpy 数组
    prob_np = prob.squeeze(0).squeeze(0).numpy()  # 从 1xwxh 转为 wxh

    # 使用 plt.imshow 可视化
    plt.imshow(prob_np)
               # , cmap='gray', vmin=0, vmax=1)  # cmap='gray' 确保图像以灰度显示
    plt.axis('off')  # 关闭坐标轴
    
    # 保存图像
    plt.savefig(filename, bbox_inches='tight', pad_inches=0)
    plt.close()
    print(f"Probability map saved as {filename}")

def pad_to_square(x: torch.Tensor, target_size: int) -> torch.Tensor:
    """Pad the input tensor to a square shape with the specified target size."""
    # Get the current height and width of the image
    h, w = x.shape[-2:]
    
    # Calculate padding for height and width
    padh = target_size - h
    padw = target_size - w
    
    # Pad the tensor to the target size
    x = F.pad(x, (0, padw, 0, padh))
    return x

def remove_none_values(input_dict):
    """
    Remove all items with None as their value from the dictionary.

    Args:
        input_dict (dict): The dictionary from which to remove None values.

    Returns:
        dict: A new dictionary with None values removed.
    """
    return {key: value for key, value in input_dict.items() if value is not None}

class PPC_SAM():
    def __init__(self, model_type="vit_h", 
                 ckpt_vit="pretrained_checkpoint/sam_vit_h_4b8939.pth", 
                 ckpt_ppc="pretrained_checkpoint/ppc_decoder.pth",
                 ckpt_hq="pretrained_checkpoint/sam_hq_vit_h_decoder.pth",
                 device = "cpu") -> None:
        # Call the parent class's __init__ method first

        self.device = device

        # Initialize the decoders
        self.sam_hq_decoder = MaskDecoderHQ(model_type)
        self.ppc_decoder = sam_decoder_reg['default']()

        # Load state dictionaries
        model_state_hq = torch.load(ckpt_hq, map_location=device)
        self.sam_hq_decoder.load_state_dict(model_state_hq)
        print(f"Loaded HQ decoder checkpoint from {ckpt_hq}")

        model_state_ppc = torch.load(ckpt_ppc, map_location=device)
        self.ppc_decoder.load_state_dict(model_state_ppc)
        print(f"Loaded PPC decoder checkpoint from {ckpt_ppc}")

        # Initialize the SAM model
        self.sam = sam_model_registry[model_type](checkpoint=ckpt_vit).to(device)
        

    def predict(self, prompts, multimask_ouput=False):
        with torch.no_grad():
            self.sam = self.sam.to(self.device)
            self.sam_hq_decoder = self.sam_hq_decoder.to(self.device)
            self.ppc_decoder = self.ppc_decoder.to(self.device)
            
            batch_input = remove_none_values(prompts[0])
            original_size = batch_input["image"].shape[:2]
            batch_input["original_size"] = original_size

            input_image = trans.apply_image(batch_input["image"])
            input_image_torch = torch.as_tensor(input_image, device=self.device)
            input_image_torch = input_image_torch.permute(2, 0, 1).contiguous()
            batch_input["image"] = input_image_torch

            if "boxes" in batch_input:
                batch_input["boxes"] = trans.apply_boxes_torch(batch_input["boxes"], original_size=original_size)
            if "point_coords" in batch_input:
                batch_input["point_coords"] = trans.apply_coords_torch(batch_input["point_coords"], original_size=original_size)


            batched_output, interm_embeddings = self.sam([batch_input], multimask_output=multimask_ouput)   

            batch_len = len(batched_output)
            encoder_embedding = torch.cat([batched_output[i_l]['encoder_embedding'] for i_l in range(batch_len)], dim=0)
            image_pe = [batched_output[i_l]['image_pe'] for i_l in range(batch_len)]
            sparse_embeddings = [batched_output[i_l]['sparse_embeddings'] for i_l in range(batch_len)]
            dense_embeddings = [batched_output[i_l]['dense_embeddings'] for i_l in range(batch_len)]
                
            masks_sam_in_hq, masks_hq = self.sam_hq_decoder(
                image_embeddings=encoder_embedding,
                image_pe=image_pe,
                sparse_prompt_embeddings=sparse_embeddings,
                dense_prompt_embeddings=dense_embeddings,
                multimask_output=multimask_ouput,
                hq_token_only=False,
                interm_embeddings=interm_embeddings,
            )

            masks_sam = batched_output[0]["masks"]
          
            input_images_ppc = pad_to_square(input_image_torch[None, :,:,:], target_size=1024).float()
            mask_ppc = self.ppc_decoder(x_img=input_images_ppc, hidden_states_out=interm_embeddings, low_res_mask=masks_hq)

            rescaled_masks_hq=self.sam.postprocess_masks(masks_hq, input_size=input_image_torch.shape[-2:], original_size=original_size)
            rescaled_masks_ppc=self.sam.postprocess_masks(mask_ppc, input_size=input_image_torch.shape[-2:], original_size=original_size)

            stacked_masks = torch.stack([rescaled_masks_ppc, rescaled_masks_hq, masks_sam.to(torch.uint8)], dim=0).cpu().squeeze(1).squeeze(1)
        return stacked_masks, None, None