from dataclasses import dataclass
from typing import Any, List, Dict, Optional, Union, Tuple

import os
import cv2
import torch
import requests
import numpy as np
from PIL import Image
from transformers import AutoModelForMaskGeneration, AutoProcessor, pipeline


@dataclass
class BoundingBox:
    """Axis-aligned bounding box in pixel coordinates."""
    xmin: int
    ymin: int
    xmax: int
    ymax: int

    @property
    def xyxy(self) -> List[float]:
        """Return the box as a flat ``[xmin, ymin, xmax, ymax]`` list."""
        return [self.xmin, self.ymin, self.xmax, self.ymax]


@dataclass
class DetectionResult:
    """One detection: confidence score, text label, box, and (optionally) a mask."""
    score: float
    label: str
    box: BoundingBox
    # FIX: was Optional[np.array] — np.array is a factory function, not a type;
    # the actual stored object is an np.ndarray.
    mask: Optional[np.ndarray] = None

    @classmethod
    def from_dict(cls, detection_dict: Dict) -> 'DetectionResult':
        """Build a DetectionResult from the dict emitted by the HF detection pipeline."""
        return cls(
            score=detection_dict['score'],
            label=detection_dict['label'],
            box=BoundingBox(
                xmin=detection_dict['box']['xmin'],
                ymin=detection_dict['box']['ymin'],
                xmax=detection_dict['box']['xmax'],
                ymax=detection_dict['box']['ymax'],
            ),
        )


def mask_to_polygon(mask: np.ndarray) -> List[List[int]]:
    """Return the vertices of the largest external contour of a binary mask.

    Args:
        mask: 2-D binary mask (any nonzero pixel counts as foreground).

    Returns:
        List of ``[x, y]`` vertices of the largest contour, or an empty list
        when the mask has no foreground pixels.
        (FIX: the original called ``max()`` on an empty contour sequence and
        raised ``ValueError`` for an all-background mask.)
    """
    contours, _ = cv2.findContours(
        mask.astype(np.uint8), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE
    )
    if not contours:
        return []
    largest_contour = max(contours, key=cv2.contourArea)
    return largest_contour.reshape(-1, 2).tolist()


def polygon_to_mask(
    polygon: List[Tuple[int, int]], image_shape: Tuple[int, int]
) -> np.ndarray:
    """
    Convert a polygon to a segmentation mask.

    Args:
    - polygon (list): List of (x, y) coordinates representing the vertices of the polygon.
    - image_shape (tuple): Shape of the image (height, width) for the mask.

    Returns:
    - np.ndarray: Segmentation mask (uint8, 0/255) with the polygon filled.
    """
    mask = np.zeros(image_shape, dtype=np.uint8)
    if not polygon:
        # FIX: cv2.fillPoly rejects an empty point array; an empty polygon
        # simply yields an all-zero mask.
        return mask
    pts = np.array(polygon, dtype=np.int32)
    cv2.fillPoly(mask, [pts], color=(255,))
    return mask


def load_image(image_str: str) -> Image.Image:
    """Load an RGB PIL image from an http(s) URL or a local file path."""
    if image_str.startswith("http"):
        image = Image.open(requests.get(image_str, stream=True).raw).convert("RGB")
    else:
        image = Image.open(image_str).convert("RGB")
    return image


# FIX: the parameter is a list of DetectionResult (it is iterated), not a single one.
def get_boxes(results: List[DetectionResult]) -> List[List[List[float]]]:
    """Collect detection boxes into the nested-list layout SAM's processor expects
    (one list of boxes per image; we always process a single image)."""
    boxes = [result.box.xyxy for result in results]
    return [boxes]


def refine_masks(
    masks: torch.BoolTensor, polygon_refinement: bool = False
) -> List[np.ndarray]:
    """Collapse SAM's multi-channel mask output to one uint8 mask per detection.

    Args:
        masks: SAM mask tensor of shape (num_boxes, C, H, W).
        polygon_refinement: when True, re-draw each mask from its largest
            contour polygon, removing small holes/speckles (output becomes 0/255
            instead of 0/1).

    Returns:
        List of 2-D uint8 masks, one per detection.
    """
    masks = masks.cpu().float()
    masks = masks.permute(0, 2, 3, 1)
    # Average the channel dimension, then binarize: a pixel is foreground if
    # any channel voted for it.
    masks = masks.mean(axis=-1)
    masks = (masks > 0).int()
    masks = masks.numpy().astype(np.uint8)
    masks = list(masks)

    if polygon_refinement:
        for idx, mask in enumerate(masks):
            shape = mask.shape
            polygon = mask_to_polygon(mask)
            mask = polygon_to_mask(polygon, shape)
            masks[idx] = mask

    return masks


# In[6]:


def detect(
    image: Image.Image,
    labels: List[str],
    threshold: float = 0.3,
    detector_id: Optional[str] = None
) -> List[DetectionResult]:
    """
    Use Grounding DINO to detect a set of labels in an image in a zero-shot fashion.

    Args:
        image: input RGB image.
        labels: text prompts; a trailing "." is appended when missing, as
            Grounding DINO expects period-terminated phrases.
        threshold: minimum detection confidence.
        detector_id: HF model id (defaults to "IDEA-Research/grounding-dino-tiny").

    Returns:
        List of DetectionResult (masks not yet populated).
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"
    detector_id = detector_id if detector_id is not None else "IDEA-Research/grounding-dino-tiny"
    # NOTE: the pipeline is rebuilt on every call; cache it externally if this
    # function is called in a loop.
    object_detector = pipeline(
        model=detector_id, task="zero-shot-object-detection", device=device
    )

    labels = [label if label.endswith(".") else label + "." for label in labels]

    results = object_detector(image, candidate_labels=labels, threshold=threshold)
    results = [DetectionResult.from_dict(result) for result in results]

    return results


def segment(
    image: Image.Image,
    detection_results: List[DetectionResult],
    polygon_refinement: bool = False,
    segmenter_id: Optional[str] = None
) -> List[DetectionResult]:
    """
    Use Segment Anything (SAM) to generate masks given an image + a set of bounding boxes.

    Mutates each element of ``detection_results`` in place by assigning its
    ``mask`` attribute, and returns the same list.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"
    segmenter_id = segmenter_id if segmenter_id is not None else "facebook/sam-vit-base"

    segmentator = AutoModelForMaskGeneration.from_pretrained(segmenter_id).to(device)
    processor = AutoProcessor.from_pretrained(segmenter_id)

    boxes = get_boxes(detection_results)
    inputs = processor(images=image, input_boxes=boxes, return_tensors="pt").to(device)

    outputs = segmentator(**inputs)
    masks = processor.post_process_masks(
        masks=outputs.pred_masks,
        original_sizes=inputs.original_sizes,
        reshaped_input_sizes=inputs.reshaped_input_sizes
    )[0]

    masks = refine_masks(masks, polygon_refinement)

    for detection_result, mask in zip(detection_results, masks):
        detection_result.mask = mask

    return detection_results


def grounded_segmentation(
    image: Union[Image.Image, str],
    labels: List[str],
    threshold: float = 0.3,
    polygon_refinement: bool = False,
    detector_id: Optional[str] = None,
    segmenter_id: Optional[str] = None
) -> Tuple[Image.Image, List[DetectionResult]]:
    """Run zero-shot detection (Grounding DINO) then mask generation (SAM).

    Args:
        image: a PIL image or a URL/path (loaded via ``load_image``).

    Returns:
        Tuple of (loaded image, detections with masks populated).
    """
    if isinstance(image, str):
        image = load_image(image)

    detections = detect(image, labels, threshold, detector_id)
    detections = segment(image, detections, polygon_refinement, segmenter_id)

    return image, detections


# In[7]:


# save clipped images
def cut_image(image, mask, box):
    """Apply ``mask`` to ``image`` and crop to ``box``.

    Args:
        image: PIL image (or array-like) in RGB order.
        mask: 2-D mask; any nonzero pixel is kept.
        box: BoundingBox delimiting the crop region.

    Returns:
        BGR uint8 array (ready for ``cv2.imwrite``) of the masked crop.
    """
    np_image = np.array(image)
    # FIX: the original computed mask.astype(np.uint8) * 255, which overflows
    # uint8 when the mask is already 0/255 (polygon_refinement=True path).
    # Normalizing via (mask > 0) handles both 0/1 and 0/255 masks.
    binary_mask = (mask > 0).astype(np.uint8) * 255
    cut = cv2.bitwise_and(np_image, np_image, mask=binary_mask)
    x0, y0, x1, y1 = map(int, box.xyxy)
    cropped = cut[y0:y1, x0:x1]
    cropped_bgr = cv2.cvtColor(cropped, cv2.COLOR_RGB2BGR)
    return cropped_bgr