import torch
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
from typing import List

from models import MaskDecoderHQ
from ppc_decoder import sam_decoder_reg
from segment_anything import sam_model_registry
from utils.transforms import ResizeLongestSide

trans = ResizeLongestSide(target_length=1024)


def save_prob_visualization(prob, filename="prob_visualization.png"):
    """
    Visualize a probability map of shape 1 x W x H with plt.imshow and save it to disk.
    :param prob: tensor of shape 1 x W x H
    :param filename: output file name, defaults to 'prob_visualization.png'
    """
    # Convert prob to a numpy array of shape W x H
    prob_np = prob.squeeze(0).squeeze(0).numpy()
    # Visualize with plt.imshow
    plt.imshow(prob_np)
    # Optionally pass cmap='gray', vmin=0, vmax=1 to force a grayscale display
    plt.axis('off')  # hide the axes
    # Save the figure
    plt.savefig(filename, bbox_inches='tight', pad_inches=0)
    plt.close()
    print(f"Probability map saved as {filename}")


def pad_to_square(x: torch.Tensor, target_size: int) -> torch.Tensor:
    """Pad the input tensor to a square shape with the specified target size."""
    # Get the current height and width of the image
    h, w = x.shape[-2:]
    # Calculate padding for height and width
    padh = target_size - h
    padw = target_size - w
    # Pad the tensor on the bottom and right to reach the target size
    x = F.pad(x, (0, padw, 0, padh))
    return x
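
# Example (illustrative): padding a resized 3 x 600 x 800 image up to the
# 1024 x 1024 square input expected by the SAM image encoder.
#   >>> pad_to_square(torch.zeros(3, 600, 800), target_size=1024).shape
#   torch.Size([3, 1024, 1024])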


def remove_none_values(input_dict):
    """
    Remove all items with None as their value from the dictionary.

    Args:
        input_dict (dict): The dictionary from which to remove None values.

    Returns:
        dict: A new dictionary with None values removed.
    """
    return {key: value for key, value in input_dict.items() if value is not None}
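
# Example (illustrative): dropping prompt fields that were left as None.
#   >>> remove_none_values({"boxes": torch.ones(1, 4), "point_coords": None})
#   {'boxes': tensor([[1., 1., 1., 1.]])}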


class PPC_SAM():
    def __init__(self, model_type="vit_h",
                 ckpt_vit="pretrained_checkpoint/sam_vit_h_4b8939.pth",
                 ckpt_ppc="pretrained_checkpoint/ppc_decoder.pth",
                 ckpt_hq="pretrained_checkpoint/sam_hq_vit_h_decoder.pth",
                 device="cpu") -> None:
        self.device = device
        # Initialize the decoders
        self.sam_hq_decoder = MaskDecoderHQ(model_type)
        self.ppc_decoder = sam_decoder_reg['default']()
        # Load state dictionaries
        model_state_hq = torch.load(ckpt_hq, map_location=device)
        self.sam_hq_decoder.load_state_dict(model_state_hq)
        print(f"Loaded HQ decoder checkpoint from {ckpt_hq}")
        model_state_ppc = torch.load(ckpt_ppc, map_location=device)
        self.ppc_decoder.load_state_dict(model_state_ppc)
        print(f"Loaded PPC decoder checkpoint from {ckpt_ppc}")
        # Initialize the SAM model
        self.sam = sam_model_registry[model_type](checkpoint=ckpt_vit).to(device)

    def predict(self, prompts, multimask_output=False):
        with torch.no_grad():
            self.sam = self.sam.to(self.device)
            self.sam_hq_decoder = self.sam_hq_decoder.to(self.device)
            self.ppc_decoder = self.ppc_decoder.to(self.device)
            # Drop unset prompt fields and record the original image size
            batch_input = remove_none_values(prompts[0])
            original_size = batch_input["image"].shape[:2]
            batch_input["original_size"] = original_size
            # Resize the image so its longest side is 1024 and move it to CHW layout
            input_image = trans.apply_image(batch_input["image"])
            input_image_torch = torch.as_tensor(input_image, device=self.device)
            input_image_torch = input_image_torch.permute(2, 0, 1).contiguous()
            batch_input["image"] = input_image_torch
            # Rescale box and point prompts to the resized image coordinates
            if "boxes" in batch_input:
                batch_input["boxes"] = trans.apply_boxes_torch(batch_input["boxes"], original_size=original_size)
            if "point_coords" in batch_input:
                batch_input["point_coords"] = trans.apply_coords_torch(batch_input["point_coords"], original_size=original_size)
            # Run the SAM backbone and collect the embeddings needed by the extra decoders
            batched_output, interm_embeddings = self.sam([batch_input], multimask_output=multimask_output)
            batch_len = len(batched_output)
            encoder_embedding = torch.cat([batched_output[i_l]['encoder_embedding'] for i_l in range(batch_len)], dim=0)
            image_pe = [batched_output[i_l]['image_pe'] for i_l in range(batch_len)]
            sparse_embeddings = [batched_output[i_l]['sparse_embeddings'] for i_l in range(batch_len)]
            dense_embeddings = [batched_output[i_l]['dense_embeddings'] for i_l in range(batch_len)]
            # The HQ decoder refines the SAM prediction
            masks_sam_in_hq, masks_hq = self.sam_hq_decoder(
                image_embeddings=encoder_embedding,
                image_pe=image_pe,
                sparse_prompt_embeddings=sparse_embeddings,
                dense_prompt_embeddings=dense_embeddings,
                multimask_output=multimask_output,
                hq_token_only=False,
                interm_embeddings=interm_embeddings,
            )
            masks_sam = batched_output[0]["masks"]
            # The PPC decoder takes the padded image, intermediate features, and HQ masks
            input_images_ppc = pad_to_square(input_image_torch[None, :, :, :], target_size=1024).float()
            mask_ppc = self.ppc_decoder(x_img=input_images_ppc, hidden_states_out=interm_embeddings, low_res_mask=masks_hq)
            # Upscale both predictions back to the original image resolution
            rescaled_masks_hq = self.sam.postprocess_masks(masks_hq, input_size=input_image_torch.shape[-2:], original_size=original_size)
            rescaled_masks_ppc = self.sam.postprocess_masks(mask_ppc, input_size=input_image_torch.shape[-2:], original_size=original_size)
            # Stack PPC, HQ, and vanilla SAM masks along a new leading dimension
            stacked_masks = torch.stack([rescaled_masks_ppc, rescaled_masks_hq, masks_sam.to(torch.uint8)], dim=0).cpu().squeeze(1).squeeze(1)
            return stacked_masks, None, None
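

if __name__ == "__main__":
    # Minimal usage sketch (illustrative, not part of the original module).
    # Assumptions: the default checkpoint paths above exist locally, PIL is
    # available, "example.jpg" is a placeholder path to an RGB image, and the
    # prompt layout follows SAM's batched-input convention
    # (point_coords B x N x 2, point_labels B x N).
    import numpy as np
    from PIL import Image

    model = PPC_SAM(device="cpu")
    image = np.array(Image.open("example.jpg").convert("RGB"))  # H x W x 3 uint8
    prompts = [{
        "image": image,
        "point_coords": torch.tensor([[[image.shape[1] // 2, image.shape[0] // 2]]], dtype=torch.float),
        "point_labels": torch.tensor([[1]], dtype=torch.int),
        "boxes": None,  # unset fields are removed by remove_none_values
    }]
    # Returns PPC, HQ, and vanilla SAM masks stacked along dim 0, at the original resolution
    stacked_masks, _, _ = model.predict(prompts)
    save_prob_visualization(stacked_masks[0:1].float(), "mask_ppc.png")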