import numpy as np
import torch
from torch.nn import functional as F
import cv2
from detectron2.data import MetadataCatalog
from detectron2.structures import BitMasks
from detectron2.utils.visualizer import ColorMode, Visualizer
import open_clip
from sam2.build_sam import build_sam2
from sam2.automatic_mask_generator import SAM2AutomaticMaskGenerator
from sam2.sam2_image_predictor import SAM2ImagePredictor
from .modeling.meta_arch.mask_adapter_head import build_mask_adapter
from PIL import Image

PIXEL_MEAN = [122.7709383, 116.7460125, 104.09373615]
PIXEL_STD = [68.5005327, 66.6321579, 70.32316305]
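
# Note: PIXEL_MEAN / PIXEL_STD are the standard CLIP normalization statistics
# (mean ~[0.481, 0.458, 0.408], std ~[0.269, 0.261, 0.276]) scaled to the
# [0, 255] pixel range used below.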


class OpenVocabVisualizer(Visualizer):
    def __init__(self, img_rgb, metadata=None, scale=1.0, instance_mode=ColorMode.IMAGE, class_names=None):
        super().__init__(img_rgb, metadata, scale, instance_mode)
        self.class_names = class_names

    def draw_sem_seg(self, sem_seg, area_threshold=None, alpha=0.6):
        """
        Draw semantic segmentation predictions/labels.

        Args:
            sem_seg (Tensor or ndarray): the segmentation of shape (H, W).
                Each value is the integer label of the pixel.
            area_threshold (int): segments with an area smaller than `area_threshold` are not drawn.
            alpha (float): the larger it is, the more opaque the segmentations are.

        Returns:
            output (VisImage): image object with visualizations.
        """
        if isinstance(sem_seg, torch.Tensor):
            sem_seg = sem_seg.numpy()
        labels, areas = np.unique(sem_seg, return_counts=True)
        sorted_idxs = np.argsort(-areas).tolist()
        labels = labels[sorted_idxs]
        class_names = self.class_names if self.class_names is not None else self.metadata.stuff_classes
        # Draw larger segments first; labels outside the class list (e.g. 255 for "unlabeled") are skipped.
        for label in filter(lambda l: l < len(class_names), labels):
            try:
                mask_color = [x / 255 for x in self.metadata.stuff_colors[label]]
            except (AttributeError, IndexError):
                mask_color = None
            binary_mask = (sem_seg == label).astype(np.uint8)
            text = class_names[label]
            self.draw_binary_mask(
                binary_mask,
                color=mask_color,
                edge_color=(1.0, 1.0, 240.0 / 255),
                text=text,
                alpha=alpha,
                area_threshold=area_threshold,
            )
        return self.output
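

# Usage sketch (illustrative, not part of the original file): visualizing a
# (H, W) integer label map with custom class names. `rgb_image`, `metadata`
# and `label_map` are assumed to exist elsewhere.
#
#   visualizer = OpenVocabVisualizer(rgb_image, metadata, class_names=["cat", "dog"])
#   vis_image = visualizer.draw_sem_seg(label_map)   # returns a detectron2 VisImage
#   vis_image.save("semseg.png")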


class SAMVisualizationDemo(object):
    def __init__(self, cfg, granularity, sam2, clip_model, mask_adapter, instance_mode=ColorMode.IMAGE, parallel=False):
        self.metadata = MetadataCatalog.get(
            cfg.DATASETS.TEST[0] if len(cfg.DATASETS.TEST) else "__unused"
        )
        self.cpu_device = torch.device("cpu")
        self.instance_mode = instance_mode
        self.parallel = parallel
        self.granularity = granularity
        self.sam2 = sam2
        # Class-agnostic mask proposals from a grid of point prompts.
        self.predictor = SAM2AutomaticMaskGenerator(
            sam2,
            points_per_batch=16,
            pred_iou_thresh=0.8,
            stability_score_thresh=0.7,
            crop_n_layers=0,
            crop_n_points_downscale_factor=2,
            min_mask_region_area=100,
        )
        self.clip_model = clip_model
        self.mask_adapter = mask_adapter

    def extract_features_convnext(self, x):
        out = {}
        x = self.clip_model.visual.trunk.stem(x)
        out['stem'] = x.contiguous()  # os4
        for i in range(4):
            x = self.clip_model.visual.trunk.stages[i](x)
            out[f'res{i+2}'] = x.contiguous()  # res2 (os4), res3 (os8), res4 (os16), res5 (os32)
        x = self.clip_model.visual.trunk.norm_pre(x)
        out['clip_vis_dense'] = x.contiguous()
        return out

    def visual_prediction_forward_convnext(self, x):
        batch, num_query, channel = x.shape
        x = x.reshape(batch * num_query, channel, 1, 1)  # fake 2D input
        x = self.clip_model.visual.trunk.head(x)
        x = self.clip_model.visual.head(x)
        return x.view(batch, num_query, x.shape[-1])  # B x num_queries x 640

    def visual_prediction_forward_convnext_2d(self, x):
        clip_vis_dense = self.clip_model.visual.trunk.head.norm(x)
        clip_vis_dense = self.clip_model.visual.trunk.head.drop(clip_vis_dense.permute(0, 2, 3, 1))
        clip_vis_dense = self.clip_model.visual.head(clip_vis_dense).permute(0, 3, 1, 2)
        return clip_vis_dense
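
    # Note on the two projection helpers above: `visual_prediction_forward_convnext`
    # runs pooled per-mask vectors through the head by reshaping them into fake
    # 1x1 feature maps, while `visual_prediction_forward_convnext_2d` applies the
    # same norm / projection densely over the spatial feature map. Both reuse the
    # OpenCLIP ConvNeXt head modules; no new parameters are introduced here.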

    def run_on_image(self, ori_image, class_names):
        height, width, _ = ori_image.shape
        # Resize so that the longer side is 896 px before running SAM2.
        if width > height:
            new_width = 896
            new_height = int((new_width / width) * height)
        else:
            new_height = 896
            new_width = int((new_height / height) * width)
        image = cv2.resize(ori_image, (new_width, new_height))
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        ori_image = cv2.cvtColor(ori_image, cv2.COLOR_BGR2RGB)
        visualizer = OpenVocabVisualizer(ori_image, self.metadata, instance_mode=self.instance_mode, class_names=class_names)

        with torch.no_grad():  # , torch.cuda.amp.autocast():
            masks = self.predictor.generate(image)
        pred_masks = [masks[i]['segmentation'][None, :, :] for i in range(len(masks))]
        pred_masks = np.concatenate(pred_masks, axis=0)
        pred_masks = BitMasks(pred_masks)

        # Normalize with the CLIP pixel statistics.
        image = torch.as_tensor(image.astype("float32").transpose(2, 0, 1))
        pixel_mean = torch.tensor(PIXEL_MEAN).view(-1, 1, 1)
        pixel_std = torch.tensor(PIXEL_STD).view(-1, 1, 1)
        image = (image - pixel_mean) / pixel_std
        image = image.unsqueeze(0)

        if len(class_names) == 1:
            class_names.append('others')
        txts = [f'a photo of {cls_name}' for cls_name in class_names]
        text = open_clip.tokenize(txts)

        with torch.no_grad():
            self.clip_model.cuda()
            text_features = self.clip_model.encode_text(text.cuda())
            text_features /= text_features.norm(dim=-1, keepdim=True)

            features = self.extract_features_convnext(image.cuda().float())
            clip_feature = features['clip_vis_dense']
            clip_vis_dense = self.visual_prediction_forward_convnext_2d(clip_feature)

            # Mask-Adapter predicts 16 semantic activation maps per mask; a spatial
            # softmax turns each map into pooling weights over the dense CLIP features.
            semantic_activation_maps = self.mask_adapter(clip_vis_dense, pred_masks.tensor.unsqueeze(0).float().cuda())
            maps_for_pooling = F.interpolate(semantic_activation_maps, size=clip_feature.shape[-2:],
                                             mode='bilinear', align_corners=False)
            B, C = clip_feature.size(0), clip_feature.size(1)
            N = maps_for_pooling.size(1)
            num_instances = N // 16
            maps_for_pooling = F.softmax(F.logsigmoid(maps_for_pooling).view(B, N, -1), dim=-1)
            pooled_clip_feature = torch.bmm(maps_for_pooling, clip_feature.view(B, C, -1).permute(0, 2, 1))
            pooled_clip_feature = self.visual_prediction_forward_convnext(pooled_clip_feature)
            pooled_clip_feature = pooled_clip_feature.reshape(B, num_instances, 16, -1).mean(dim=-2).contiguous()

            # Classify each mask embedding against the text embeddings.
            class_preds = (100.0 * pooled_clip_feature @ text_features.T).softmax(dim=-1)
            class_preds = class_preds.squeeze(0)

        # Keep, for every class, its highest-scoring mask; with granularity < 1, also keep
        # any mask whose score exceeds granularity * max score for that class.
        select_cls = torch.zeros_like(class_preds)
        max_scores, select_mask = torch.max(class_preds, dim=0)
        if len(class_names) == 2 and class_names[-1] == 'others':
            select_mask = select_mask[:-1]
        if self.granularity < 1:
            thr_scores = max_scores * self.granularity
            select_mask = []
            if len(class_names) == 2 and class_names[-1] == 'others':
                thr_scores = thr_scores[:-1]
            for i, thr in enumerate(thr_scores):
                cls_pred = class_preds[:, i]
                locs = torch.where(cls_pred > thr)
                select_mask.extend(locs[0].tolist())
        for idx in select_mask:
            select_cls[idx] = class_preds[idx]
        semseg = torch.einsum("qc,qhw->chw", select_cls.float(), pred_masks.tensor.float().cuda())

        r = semseg
        blank_area = (r[0] == 0)
        pred_mask = r.argmax(dim=0)
        pred_mask[blank_area] = 255  # 255 marks pixels not covered by any selected mask
        # cv2.resize cannot handle int64 label maps, so cast to uint8
        # (sufficient as long as there are fewer than 255 classes).
        pred_mask = pred_mask.cpu().numpy().astype(np.uint8)
        pred_mask = cv2.resize(pred_mask, (width, height), interpolation=cv2.INTER_NEAREST)

        vis_output = visualizer.draw_sem_seg(pred_mask)
        return None, vis_output
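
# Usage sketch (illustrative; `cfg`, `sam2_model`, `clip_model` and `mask_adapter`
# are hypothetical names for objects built elsewhere, e.g. in the Space's app.py):
#
#   demo = SAMVisualizationDemo(cfg, granularity=0.8, sam2=sam2_model,
#                               clip_model=clip_model, mask_adapter=mask_adapter)
#   bgr = cv2.imread("input.jpg")                  # run_on_image expects a BGR array
#   _, vis_output = demo.run_on_image(bgr, ["person", "car", "building"])
#   vis_output.save("semseg_vis.png")              # vis_output is a detectron2 VisImage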


class SAMPointVisualizationDemo(object):
    def __init__(self, cfg, granularity, sam2, clip_model, mask_adapter, instance_mode=ColorMode.IMAGE, parallel=False):
        self.metadata = MetadataCatalog.get(
            cfg.DATASETS.TEST[0] if len(cfg.DATASETS.TEST) else "__unused"
        )
        self.cpu_device = torch.device("cpu")
        self.instance_mode = instance_mode
        self.parallel = parallel
        self.granularity = granularity
        self.sam2 = sam2
        # Interactive, point-prompted mask prediction (one mask per click).
        self.predictor = SAM2ImagePredictor(sam2)
        self.clip_model = clip_model
        self.mask_adapter = mask_adapter

        # COCO (panoptic) + LVIS vocabulary; the matching text embeddings are precomputed.
        self.class_names = self._load_class_names()
        self.text_embedding = torch.from_numpy(
            np.load("./text_embedding/lvis_coco_text_embedding.npy")
        ).to("cuda")

    def _load_class_names(self):
        from .data.datasets import openseg_classes
        COCO_CATEGORIES_pan = openseg_classes.get_coco_categories_with_prompt_eng()
        thing_classes = [k["name"] for k in COCO_CATEGORIES_pan if k["isthing"] == 1]
        stuff_classes = [k["name"] for k in COCO_CATEGORIES_pan]
        with open("./mask_adapter/data/datasets/lvis_1203_with_prompt_eng.txt", 'r') as f:
            lvis_classes = f.read().splitlines()
        lvis_classes = [x[x.find(':') + 1:] for x in lvis_classes]
        return thing_classes + stuff_classes + lvis_classes

    def extract_features_convnext(self, x):
        out = {}
        x = self.clip_model.visual.trunk.stem(x)
        out['stem'] = x.contiguous()  # os4
        for i in range(4):
            x = self.clip_model.visual.trunk.stages[i](x)
            out[f'res{i+2}'] = x.contiguous()  # res2 (os4), res3 (os8), res4 (os16), res5 (os32)
        x = self.clip_model.visual.trunk.norm_pre(x)
        out['clip_vis_dense'] = x.contiguous()
        return out

    def visual_prediction_forward_convnext(self, x):
        batch, num_query, channel = x.shape
        x = x.reshape(batch * num_query, channel, 1, 1)  # fake 2D input
        x = self.clip_model.visual.trunk.head(x)
        x = self.clip_model.visual.head(x)
        return x.view(batch, num_query, x.shape[-1])  # B x num_queries x 640

    def visual_prediction_forward_convnext_2d(self, x):
        clip_vis_dense = self.clip_model.visual.trunk.head.norm(x)
        clip_vis_dense = self.clip_model.visual.trunk.head.drop(clip_vis_dense.permute(0, 2, 3, 1))
        clip_vis_dense = self.clip_model.visual.head(clip_vis_dense).permute(0, 3, 1, 2)
        return clip_vis_dense

    def run_on_image_with_points(self, ori_image, points):
        height, width, _ = ori_image.shape

        image = ori_image
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        ori_image = cv2.cvtColor(ori_image, cv2.COLOR_BGR2RGB)

        input_point = np.array(points)
        input_label = np.array([1])  # 1 = foreground click

        with torch.no_grad():
            self.predictor.set_image(image)
            masks, _, _ = self.predictor.predict(point_coords=input_point, point_labels=input_label, multimask_output=False)
        pred_masks = BitMasks(masks)

        # Normalize with the CLIP pixel statistics.
        image = torch.as_tensor(image.astype("float32").transpose(2, 0, 1))
        pixel_mean = torch.tensor(PIXEL_MEAN).view(-1, 1, 1)
        pixel_std = torch.tensor(PIXEL_STD).view(-1, 1, 1)
        image = (image - pixel_mean) / pixel_std
        image = image.unsqueeze(0)

        # txts = [f'a photo of {cls_name}' for cls_name in self.class_names]
        # text = open_clip.tokenize(txts)

        with torch.no_grad():
            self.clip_model.cuda()
            # text_features = self.clip_model.encode_text(text.cuda())
            # text_features /= text_features.norm(dim=-1, keepdim=True)
            text_features = self.text_embedding  # precomputed COCO + LVIS text embeddings

            features = self.extract_features_convnext(image.cuda().float())
            clip_feature = features['clip_vis_dense']
            clip_vis_dense = self.visual_prediction_forward_convnext_2d(clip_feature)

            # Same mask pooling as in SAMVisualizationDemo: 16 semantic activation
            # maps per mask, spatial softmax, weighted pooling of dense CLIP features.
            semantic_activation_maps = self.mask_adapter(clip_vis_dense, pred_masks.tensor.unsqueeze(0).float().cuda())
            maps_for_pooling = F.interpolate(semantic_activation_maps, size=clip_feature.shape[-2:], mode='bilinear', align_corners=False)
            B, C = clip_feature.size(0), clip_feature.size(1)
            N = maps_for_pooling.size(1)
            num_instances = N // 16
            maps_for_pooling = F.softmax(F.logsigmoid(maps_for_pooling).view(B, N, -1), dim=-1)
            pooled_clip_feature = torch.bmm(maps_for_pooling, clip_feature.view(B, C, -1).permute(0, 2, 1))
            pooled_clip_feature = self.visual_prediction_forward_convnext(pooled_clip_feature)
            pooled_clip_feature = pooled_clip_feature.reshape(B, num_instances, 16, -1).mean(dim=-2).contiguous()

            class_preds = (100.0 * pooled_clip_feature @ text_features.T).softmax(dim=-1)
            class_preds = class_preds.squeeze(0)

        # Resize the mask to the original image size.
        pred_mask = cv2.resize(masks.squeeze(0), (width, height), interpolation=cv2.INTER_NEAREST)

        # Overlay the mask on the image with alpha blending.
        overlay = ori_image.copy()
        mask_colored = np.zeros_like(ori_image)
        mask_colored[pred_mask == 1] = [234, 103, 112]  # RGB fill color for the mask
        alpha = 0.5
        cv2.addWeighted(mask_colored, alpha, overlay, 1 - alpha, 0, overlay)

        # Draw the mask boundary (contours) on the overlay.
        contours, _ = cv2.findContours(pred_mask.astype(np.uint8), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
        cv2.drawContours(overlay, contours, -1, (255, 255, 255), 2)  # white boundary

        # Label the mask with its highest-scoring class.
        max_scores, max_score_idx = class_preds.max(dim=1)
        label = f"{self.class_names[max_score_idx.item()]}: {max_scores.item():.2f}"

        # Place the label near the clicked point while keeping it inside the image bounds.
        text_x = int(min(width - 200, points[0][0] + 20))
        text_y = int(min(height - 30, points[0][1] + 20))
        cv2.putText(overlay, label, (text_x, text_y), cv2.FONT_HERSHEY_SIMPLEX, 1, (255, 255, 255), 2)

        return None, Image.fromarray(overlay)
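
# Usage sketch for the point-prompted demo (illustrative; `cfg`, `sam2_model`,
# `clip_model` and `mask_adapter` are hypothetical names for objects built
# elsewhere, e.g. in the Space's app.py):
#
#   demo = SAMPointVisualizationDemo(cfg, 0.8, sam2_model, clip_model, mask_adapter)
#   bgr = cv2.imread("input.jpg")                        # BGR array, as read by OpenCV
#   _, overlay = demo.run_on_image_with_points(bgr, [[320, 240]])  # one (x, y) click
#   overlay.save("point_prediction.png")                 # returns a PIL.Image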