from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple, Union

import cv2
import numpy as np
import requests
import torch
from PIL import Image
from transformers import AutoModelForMaskGeneration, AutoProcessor, pipeline


@dataclass
class BoundingBox:
    xmin: int
    ymin: int
    xmax: int
    ymax: int

    @property
    def xyxy(self) -> List[int]:
        return [self.xmin, self.ymin, self.xmax, self.ymax]


@dataclass
class DetectionResult:
    score: float
    label: str
    box: BoundingBox
    mask: Optional[np.ndarray] = None

    @classmethod
    def from_dict(cls, detection_dict: Dict) -> 'DetectionResult':
        return cls(score=detection_dict['score'],
                   label=detection_dict['label'],
                   box=BoundingBox(xmin=detection_dict['box']['xmin'],
                                   ymin=detection_dict['box']['ymin'],
                                   xmax=detection_dict['box']['xmax'],
                                   ymax=detection_dict['box']['ymax']))


def mask_to_polygon(mask: np.ndarray) -> List[List[int]]:
    """Convert a binary mask to the polygon of its largest contour (assumes a non-empty mask)."""
    # Find the external contours of the binary mask.
    contours, _ = cv2.findContours(mask.astype(np.uint8), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)

    # Keep only the largest contour by area.
    largest_contour = max(contours, key=cv2.contourArea)

    # Flatten the contour into a list of (x, y) vertices.
    polygon = largest_contour.reshape(-1, 2).tolist()

    return polygon


def polygon_to_mask(polygon: List[Tuple[int, int]], image_shape: Tuple[int, int]) -> np.ndarray:
    """
    Convert a polygon to a segmentation mask.

    Args:
    - polygon (list): List of (x, y) coordinates representing the vertices of the polygon.
    - image_shape (tuple): Shape of the image (height, width) for the mask.

    Returns:
    - np.ndarray: Segmentation mask with the polygon filled.
    """
    mask = np.zeros(image_shape, dtype=np.uint8)

    pts = np.array(polygon, dtype=np.int32)

    cv2.fillPoly(mask, [pts], color=(255,))

    return mask


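# The two helpers above are inverses up to contour simplification: converting a
# mask to its largest polygon and back yields a filled approximation of the
# original mask. The function below is an added illustrative sketch, not part
# of the original pipeline; its name and the 64x64 test mask are hypothetical.
def _polygon_roundtrip_demo() -> None:
    # Build a simple binary mask containing one filled square.
    demo_mask = np.zeros((64, 64), dtype=np.uint8)
    demo_mask[16:48, 16:48] = 255

    polygon = mask_to_polygon(demo_mask)
    restored = polygon_to_mask(polygon, demo_mask.shape)

    # The restored mask should closely match the original square.
    overlap = np.logical_and(demo_mask > 0, restored > 0).sum()
    print(f"overlap pixels: {overlap} / {int((demo_mask > 0).sum())}")

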
def load_image(image_str: str) -> Image.Image:
    # Accept either a URL or a local file path.
    if image_str.startswith("http"):
        image = Image.open(requests.get(image_str, stream=True).raw).convert("RGB")
    else:
        image = Image.open(image_str).convert("RGB")

    return image


def get_boxes(results: List[DetectionResult]) -> List[List[List[float]]]:
    # SAM's processor expects one list of boxes per image, hence the extra nesting.
    boxes = []
    for result in results:
        xyxy = result.box.xyxy
        boxes.append(xyxy)

    return [boxes]


def refine_masks(masks: torch.BoolTensor, polygon_refinement: bool = False) -> List[np.ndarray]:
    # Collapse the mask candidates and binarize: (N, C, H, W) -> (N, H, W).
    # A pixel is kept if any of the C candidate masks marks it as foreground.
    masks = masks.cpu().float()
    masks = masks.permute(0, 2, 3, 1)
    masks = masks.mean(dim=-1)
    masks = (masks > 0).int()
    masks = masks.numpy().astype(np.uint8)
    masks = list(masks)

    # Optionally smooth each mask by keeping only its largest polygon.
    if polygon_refinement:
        for idx, mask in enumerate(masks):
            shape = mask.shape
            polygon = mask_to_polygon(mask)
            mask = polygon_to_mask(polygon, shape)
            masks[idx] = mask

    return masks


def detect(
    image: Image.Image,
    labels: List[str],
    threshold: float = 0.3,
    detector_id: Optional[str] = None
) -> List[DetectionResult]:
    """
    Use Grounding DINO to detect a set of labels in an image in a zero-shot fashion.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"
    detector_id = detector_id if detector_id is not None else "IDEA-Research/grounding-dino-tiny"
    object_detector = pipeline(model=detector_id, task="zero-shot-object-detection", device=device)

    # Grounding DINO expects each text prompt to end with a period.
    labels = [label if label.endswith(".") else label + "." for label in labels]

    results = object_detector(image, candidate_labels=labels, threshold=threshold)
    results = [DetectionResult.from_dict(result) for result in results]

    return results


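# Illustrative call (added sketch; the URL is a placeholder example). Something like
#   detections = detect(load_image("http://images.cocodataset.org/val2017/000000039769.jpg"), ["cat"])
# returns a list of DetectionResult objects, each carrying a confidence score,
# a label such as "cat.", and a pixel-space BoundingBox.

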
def segment(
    image: Image.Image,
    detection_results: List[DetectionResult],
    polygon_refinement: bool = False,
    segmenter_id: Optional[str] = None
) -> List[DetectionResult]:
    """
    Use Segment Anything (SAM) to generate masks given an image + a set of bounding boxes.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"
    segmenter_id = segmenter_id if segmenter_id is not None else "facebook/sam-vit-base"

    segmentator = AutoModelForMaskGeneration.from_pretrained(segmenter_id).to(device)
    processor = AutoProcessor.from_pretrained(segmenter_id)

    # Prompt SAM with the detector's bounding boxes.
    boxes = get_boxes(detection_results)
    inputs = processor(images=image, input_boxes=boxes, return_tensors="pt").to(device)

    outputs = segmentator(**inputs)
    masks = processor.post_process_masks(
        masks=outputs.pred_masks,
        original_sizes=inputs.original_sizes,
        reshaped_input_sizes=inputs.reshaped_input_sizes
    )[0]

    masks = refine_masks(masks, polygon_refinement)

    # Attach each mask to its corresponding detection.
    for detection_result, mask in zip(detection_results, masks):
        detection_result.mask = mask

    return detection_results


def grounded_segmentation(
    image: Union[Image.Image, str],
    labels: List[str],
    threshold: float = 0.3,
    polygon_refinement: bool = False,
    detector_id: Optional[str] = None,
    segmenter_id: Optional[str] = None
) -> Tuple[Image.Image, List[DetectionResult]]:
    if isinstance(image, str):
        image = load_image(image)

    detections = detect(image, labels, threshold, detector_id)
    detections = segment(image, detections, polygon_refinement, segmenter_id)

    return image, detections


def cut_image(image: Image.Image, mask: np.ndarray, box: BoundingBox) -> np.ndarray:
    """Mask out everything but the object and crop it to its bounding box (returns BGR)."""
    np_image = np.array(image)
    # Zero out pixels outside the mask; any nonzero mask value counts as foreground.
    cut = cv2.bitwise_and(np_image, np_image, mask=(mask > 0).astype(np.uint8))
    x0, y0, x1, y1 = map(int, box.xyxy)
    cropped = cut[y0:y1, x0:x1]
    # Convert RGB (PIL) to BGR so the crop can be written out with cv2.imwrite.
    cropped_bgr = cv2.cvtColor(cropped, cv2.COLOR_RGB2BGR)
    return cropped_bgr
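

# End-to-end usage sketch (added illustration). The image URL, label list, and
# output filenames are placeholder assumptions, not part of the original script.
if __name__ == "__main__":
    image_url = "http://images.cocodataset.org/val2017/000000039769.jpg"  # hypothetical example input
    image, detections = grounded_segmentation(
        image=image_url,
        labels=["cat"],
        threshold=0.3,
        polygon_refinement=True,
    )

    # Crop each detected object and save it as a separate BGR image.
    for i, detection in enumerate(detections):
        crop = cut_image(image, detection.mask, detection.box)
        cv2.imwrite(f"cut_{i}_{detection.label.rstrip('.')}.png", crop)
        print(f"{detection.label} score={detection.score:.2f} box={detection.box.xyxy}")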