# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""CAM utils."""
# pylint: disable=g-importing-member
import os

import cv2
import numpy as np
from PIL import Image
from scipy.ndimage import binary_fill_holes
import torch
from torchvision.transforms import Compose
from torchvision.transforms import Normalize
from torchvision.transforms import Resize
from torchvision.transforms import ToTensor

# pylint: disable=g-import-not-at-top
try:
  from torchvision.transforms import InterpolationMode
  BICUBIC = InterpolationMode.BICUBIC
except ImportError:
  BICUBIC = Image.BICUBIC
# cv2.findContours returns (image, contours, hierarchy) in OpenCV 3.x but
# (contours, hierarchy) in 2.x and 4.x, so the contour list sits at a
# different index depending on the installed version.
_CONTOUR_INDEX = 1 if cv2.__version__.split('.')[0] == '3' else 0


def _convert_image_to_rgb(image):
  return image.convert('RGB')


def _transform_resize(h, w):
  """Builds a preprocessing pipeline that resizes inputs to (h, w)."""
  return Compose([
      Resize((h, w), interpolation=BICUBIC),
      _convert_image_to_rgb,
      ToTensor(),
      # CLIP's published image normalization statistics.
      Normalize(
          (0.48145466, 0.4578275, 0.40821073),
          (0.26862954, 0.26130258, 0.27577711),
      ),
  ])
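
# Example usage (illustrative sketch, not part of the original module; `img`
# is a hypothetical PIL image):
#   preprocess = _transform_resize(224, 224)
#   tensor = preprocess(img)  # float tensor of shape (3, 224, 224)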


def img_ms_and_flip(image, ori_height, ori_width, scales=1.0, patch_size=16):
  """Resizes the image at multiple scales and adds a flipped copy of each."""
  if isinstance(scales, float):
    scales = [scales]
  all_imgs = []
  for scale in scales:
    preprocess = _transform_resize(
        int(np.ceil(scale * int(ori_height) / patch_size) * patch_size),
        int(np.ceil(scale * int(ori_width) / patch_size) * patch_size),
    )
    # Preprocess the original PIL image at every scale; assigning the result
    # back to `image` would feed a tensor into the next iteration and crash.
    image_ori = preprocess(image)
    # Horizontal flip: the last tensor dimension is width.
    image_flip = torch.flip(image_ori, [-1])
    all_imgs.append(image_ori)
    all_imgs.append(image_flip)
  return all_imgs
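
# Example usage (illustrative sketch; `img` is a hypothetical 480x640 PIL
# image, both sides already multiples of the patch size):
#   imgs = img_ms_and_flip(img, 480, 640, scales=[1.0, 1.5])
#   len(imgs)      # 4: (original, flipped) for each scale
#   imgs[0].shape  # torch.Size([3, 480, 640])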


def reshape_transform(tensor, height=28, width=28):
  """Reshapes ViT tokens (seq_len, batch, ch) to a (batch, ch, h, w) map."""
  tensor = tensor.permute(1, 0, 2)  # -> (batch, seq_len, channels)
  # Drop the class token at index 0 and unflatten the patches to a 2D grid.
  result = tensor[:, 1:, :].reshape(
      tensor.size(0), height, width, tensor.size(2)
  )
  # Bring the channels to the first dimension, like in CNNs.
  result = result.transpose(2, 3).transpose(1, 2)
  return result
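
# Example (illustrative sketch with made-up sizes): activations with a class
# token plus a 28x28 patch grid,
#   acts = torch.randn(1 + 28 * 28, 2, 512)  # (seq_len, batch, channels)
#   reshape_transform(acts).shape            # torch.Size([2, 512, 28, 28])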


def vis_mask(image, mask, mask_color):
  """Overlays `mask_color` on the regions of `image` where `mask` > 0.5."""
  if mask.shape[0] != image.shape[0] or mask.shape[1] != image.shape[1]:
    mask = cv2.resize(mask, (image.shape[1], image.shape[0]))
  fg = mask > 0.5
  rgb = np.copy(image)
  # Blend 30% of the original pixels with 70% of the overlay color.
  rgb[fg] = (rgb[fg] * 0.3 + np.array(mask_color) * 0.7).astype(np.uint8)
  return Image.fromarray(rgb)
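
# Example usage (illustrative sketch; `img` is a hypothetical HxWx3 uint8
# array and `m` an HxW array in [0, 1]):
#   overlay = vis_mask(img, m, mask_color=(255, 0, 0))  # PIL image, red overlay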


def scoremap2bbox(scoremap, threshold, multi_contour_eval=False):
  """Extracts bounding boxes from a score map.

  Returns a (num_boxes, 4) array of [x0, y0, x1, y1] boxes and the number of
  contours found.
  """
  height, width = scoremap.shape
  scoremap_image = np.expand_dims((scoremap * 255).astype(np.uint8), 2)
  # Binarize at `threshold` times the maximum score; if nothing survives,
  # relax the threshold until some foreground appears.
  while True:
    _, thr_gray_heatmap = cv2.threshold(
        src=scoremap_image,
        thresh=int(threshold * np.max(scoremap_image)),
        maxval=255,
        type=cv2.THRESH_BINARY,
    )
    if thr_gray_heatmap.max() > 0 or threshold <= 0:
      break
    threshold -= 0.1

  contours = cv2.findContours(
      image=thr_gray_heatmap, mode=cv2.RETR_TREE, method=cv2.CHAIN_APPROX_SIMPLE
  )[_CONTOUR_INDEX]

  if not contours:
    return np.asarray([[0, 0, 0, 0]]), 1

  if not multi_contour_eval:
    contours = [max(contours, key=cv2.contourArea)]

  estimated_boxes = []
  for contour in contours:
    x, y, w, h = cv2.boundingRect(contour)
    x0, y0, x1, y1 = x, y, x + w, y + h
    x1 = min(x1, width - 1)
    y1 = min(y1, height - 1)
    estimated_boxes.append([x0, y0, x1, y1])
  return np.asarray(estimated_boxes), len(contours)
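
# Example usage (illustrative sketch; `cam` is a hypothetical HxW score map
# normalized to [0, 1]):
#   boxes, num = scoremap2bbox(cam, threshold=0.4, multi_contour_eval=True)
#   boxes[0]  # e.g. array([x0, y0, x1, y1]) for the first activated region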


def mask2chw(arr):
  """Returns the (center_y, center_x), height, and width of a binary mask."""
  # Find the row and column indices where the array is 1.
  rows, cols = np.where(arr == 1)
  # Calculate the center of the mask.
  center_y = int(np.mean(rows))
  center_x = int(np.mean(cols))
  # Calculate the height and width of the mask's bounding box.
  height = rows.max() - rows.min() + 1
  width = cols.max() - cols.min() + 1
  return (center_y, center_x), height, width
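
# Example (illustrative sketch): a 2x2 block of ones inside a 5x5 mask,
#   m = np.zeros((5, 5), dtype=np.uint8)
#   m[1:3, 2:4] = 1
#   mask2chw(m)  # ((1, 2), 2, 2): center (y=1, x=2), height 2, width 2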


def unpad(image_array, pad=None):
  """Crops away padding given as (left, top, width, height)."""
  if pad is not None:
    left, top, width, height = pad
    image_array = image_array[top : top + height, left : left + width, :]
  return image_array
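
# Example usage (illustrative sketch; `padded` is a hypothetical HxWx3 array
# whose content region starts at (left=10, top=20) and is 640x480):
#   content = unpad(padded, pad=(10, 20, 640, 480))  # shape (480, 640, 3)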


def apply_visual_prompts(
    image_array,
    mask,
    visual_prompt_type=('circle',),
    visualize=False,
    color=(255, 0, 0),
    thickness=1,
    blur_strength=(15, 15),
):
  """Applies visual prompts to the image."""
  prompted_image = image_array.copy()
  if 'blur' in visual_prompt_type:
    # Blur the part outside the mask: blur the entire image first.
    blurred = cv2.GaussianBlur(prompted_image.copy(), blur_strength, 0)
    # Get the sharp region using the mask.
    sharp_region = cv2.bitwise_and(
        prompted_image.copy(),
        prompted_image.copy(),
        mask=np.clip(mask, 0, 255).astype(np.uint8),
    )
    # Get the blurred region using the inverted mask.
    inv_mask = 1 - mask
    blurred_region = (blurred * inv_mask[:, :, None]).astype(np.uint8)
    # Combine the sharp and blurred regions.
    prompted_image = cv2.add(sharp_region, blurred_region)
  if 'gray' in visual_prompt_type:
    gray = cv2.cvtColor(prompted_image.copy(), cv2.COLOR_BGR2GRAY)
    # Make the grayscale image 3-channel again.
    gray = np.stack([gray, gray, gray], axis=-1)
    # Keep the original colors inside the mask.
    color_region = cv2.bitwise_and(
        prompted_image.copy(),
        prompted_image.copy(),
        mask=np.clip(mask, 0, 255).astype(np.uint8),
    )
    # Get the grayscale region using the inverted mask.
    inv_mask = 1 - mask
    gray_region = (gray * inv_mask[:, :, None]).astype(np.uint8)
    # Combine the color and grayscale regions.
    prompted_image = cv2.add(color_region, gray_region)
  if 'black' in visual_prompt_type:
    # Black out everything outside the mask.
    prompted_image = cv2.bitwise_and(
        prompted_image.copy(),
        prompted_image.copy(),
        mask=np.clip(mask, 0, 255).astype(np.uint8),
    )
  if 'circle' in visual_prompt_type:
    mask_center, mask_height, mask_width = mask2chw(mask)
    center_coordinates = (mask_center[1], mask_center[0])
    axes_length = (mask_width // 2, mask_height // 2)
    prompted_image = cv2.ellipse(
        prompted_image,
        center_coordinates,
        axes_length,
        0,
        0,
        360,
        color,
        thickness,
    )
  if 'rectangle' in visual_prompt_type:
    mask_center, mask_height, mask_width = mask2chw(mask)
    start_point = (
        mask_center[1] - mask_width // 2,
        mask_center[0] - mask_height // 2,
    )
    end_point = (
        mask_center[1] + mask_width // 2,
        mask_center[0] + mask_height // 2,
    )
    prompted_image = cv2.rectangle(
        prompted_image, start_point, end_point, color, thickness
    )
  if 'contour' in visual_prompt_type:
    # Fill holes in the mask, then find and draw its contours.
    mask = binary_fill_holes(mask)
    # Use _CONTOUR_INDEX so this also works on OpenCV 3.x, which returns an
    # extra leading value from findContours.
    contours = cv2.findContours(
        mask.astype(np.uint8), cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE
    )[_CONTOUR_INDEX]
    prompted_image = cv2.drawContours(
        prompted_image.copy(), contours, -1, color, thickness
    )
  if visualize:
    cv2.imwrite('masked_img.png', prompted_image)
  prompted_image = Image.fromarray(prompted_image.astype(np.uint8))
  return prompted_image
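

# Example usage (illustrative sketch, not part of the original module): draw a
# red circle prompt around a square mask on a synthetic gray image.
if __name__ == '__main__':
  demo_image = np.full((128, 128, 3), 127, dtype=np.uint8)
  demo_mask = np.zeros((128, 128), dtype=np.uint8)
  demo_mask[32:96, 32:96] = 1
  demo = apply_visual_prompts(
      demo_image, demo_mask, visual_prompt_type=('circle',), thickness=2
  )
  demo.save('circle_prompt_demo.png')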