import numpy as np
import torch
import torch.nn.functional as F
from skimage import filters
|
|
|
def dilate_mask(latents_mask, k, latents_dtype):
    """Binary-dilate a flat 64x64 latent mask by k pixels."""
    mask_2d = latents_mask.view(64, 64)

    # Dilation as convolution with an all-ones (2k+1)x(2k+1) kernel:
    # any pixel the kernel touches becomes foreground.
    kernel = torch.ones(2 * k + 1, 2 * k + 1, device=mask_2d.device, dtype=mask_2d.dtype)

    # conv2d expects 4D (N, C, H, W) input and weight.
    mask_4d = mask_2d.unsqueeze(0).unsqueeze(0)
    dilated_mask = F.conv2d(mask_4d, kernel.unsqueeze(0).unsqueeze(0), padding=k)

    dilated_mask = (dilated_mask > 0).to(mask_2d.dtype)
    dilated_mask = dilated_mask.view(4096, 1).to(latents_dtype)

    return dilated_mask
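
# Usage sketch for dilate_mask (the shapes follow the flat 64x64 latent
# grid assumed above; the inputs here are illustrative, not from this repo):
#
#   latents_mask = (torch.rand(4096, 1) > 0.5).float()
#   dilated = dilate_mask(latents_mask, k=2, latents_dtype=torch.float16)
#   assert dilated.shape == (4096, 1)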
|
|
|
def clipseg_predict(model, processor, image, text, device):
    inputs = processor(text=text, images=image, return_tensors="pt")
    inputs = {k: v.to(device) for k, v in inputs.items()}

    with torch.no_grad():
        outputs = model(**inputs)
        preds = outputs.logits.unsqueeze(1)
        preds = torch.sigmoid(preds)

    # Binarize the soft CLIPSeg heatmap with an Otsu threshold.
    otsu_thr = filters.threshold_otsu(preds.cpu().numpy())
    subject_mask = (preds > otsu_thr).float()

    return subject_mask
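
# Usage sketch for clipseg_predict. The checkpoint and classes below are
# an assumption (the standard HuggingFace CLIPSeg release), since this
# module does not pin them:
#
#   from transformers import CLIPSegProcessor, CLIPSegForImageSegmentation
#   processor = CLIPSegProcessor.from_pretrained("CIDAS/clipseg-rd64-refined")
#   model = CLIPSegForImageSegmentation.from_pretrained("CIDAS/clipseg-rd64-refined").to(device)
#   subject_mask = clipseg_predict(model, processor, image, ["a cat"], device)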
|
|
|
def grounding_sam_predict(model, processor, sam_predictor, image, text, device):
    inputs = processor(images=image, text=text, return_tensors="pt").to(device)
    with torch.no_grad():
        outputs = model(**inputs)

    results = processor.post_process_grounded_object_detection(
        outputs,
        inputs.input_ids,
        box_threshold=0.4,
        text_threshold=0.3,
        target_sizes=[image.size[::-1]],  # (height, width)
    )

    input_boxes = results[0]["boxes"].cpu().numpy()

    # Fall back to an all-ones latent-resolution mask if grounding
    # detects nothing.
    if input_boxes.shape[0] == 0:
        return torch.ones((64, 64), device=device)

    with torch.inference_mode(), torch.autocast("cuda", dtype=torch.bfloat16):
        sam_predictor.set_image(image)
        masks, scores, logits = sam_predictor.predict(
            point_coords=None,
            point_labels=None,
            box=input_boxes,
            multimask_output=False,
        )

    # Keep the mask for the first detected box.
    subject_mask = torch.tensor(masks[0], device=device)

    return subject_mask
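
# Usage sketch for grounding_sam_predict. The Grounding DINO checkpoint
# and the SAM 2 predictor class are assumptions; any predictor exposing
# the same set_image/predict interface should work. Grounding DINO
# expects lowercase queries ending in a period:
#
#   from transformers import AutoProcessor, AutoModelForZeroShotObjectDetection
#   from sam2.sam2_image_predictor import SAM2ImagePredictor
#   processor = AutoProcessor.from_pretrained("IDEA-Research/grounding-dino-tiny")
#   model = AutoModelForZeroShotObjectDetection.from_pretrained("IDEA-Research/grounding-dino-tiny").to(device)
#   sam_predictor = SAM2ImagePredictor.from_pretrained("facebook/sam2-hiera-large")
#   subject_mask = grounding_sam_predict(model, processor, sam_predictor, image, "a cat.", device)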
|
|
|
def mask_to_box_sam_predict(mask, sam_predictor, image, text, device):
    # PIL's Image.size is (width, height).
    W, H = image.size

    # Upsample the low-resolution mask to image resolution.
    mask = F.interpolate(mask.view(1, 1, mask.shape[-2], mask.shape[-1]), size=(H, W), mode='bilinear').view(H, W)
    mask_indices = torch.nonzero(mask)
    top_left = mask_indices.min(dim=0)[0]
    bottom_right = mask_indices.max(dim=0)[0]

    # nonzero() yields (row, col) = (y, x); SAM boxes are (x1, y1, x2, y2).
    input_boxes = np.array([[top_left[1].item(), top_left[0].item(), bottom_right[1].item(), bottom_right[0].item()]])

    with torch.inference_mode(), torch.autocast("cuda", dtype=torch.bfloat16):
        sam_predictor.set_image(image)
        masks, scores, logits = sam_predictor.predict(
            point_coords=None,
            point_labels=None,
            box=input_boxes,
            multimask_output=True,
        )

    # Union of the candidate masks from multimask output.
    subject_mask = torch.tensor(np.max(masks, axis=0), device=device)

    return subject_mask, input_boxes[0]
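
# Usage sketch: refine a coarse low-resolution mask by prompting SAM with
# its bounding box (`coarse_mask` is a hypothetical 2D tensor; `text` is
# accepted but unused, kept for a uniform interface):
#
#   refined_mask, box = mask_to_box_sam_predict(coarse_mask, sam_predictor, image, text=None, device=device)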
|
|
|
def mask_to_mask_sam_predict(mask, sam_predictor, image, text, device):
    # SAM's dense mask prompt must be a single-channel 256x256 map.
    H, W = (256, 256)

    mask = F.interpolate(mask.view(1, 1, mask.shape[-2], mask.shape[-1]), size=(H, W), mode='bilinear').view(1, H, W)
    mask_input = mask.float().cpu().numpy()

    with torch.inference_mode(), torch.autocast("cuda", dtype=torch.bfloat16):
        sam_predictor.set_image(image)
        masks, scores, logits = sam_predictor.predict(
            point_coords=None,
            point_labels=None,
            mask_input=mask_input,
            multimask_output=False,
        )

    subject_mask = torch.tensor(masks[0], device=device)

    return subject_mask
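
# Usage sketch: pass the coarse mask as SAM's dense prompt. SAM treats
# mask_input as logits, so a hard {0, 1} mask is a weak prompt; rescaling
# (e.g. mask * 20 - 10) is a common trick, though this function feeds it
# through unscaled:
#
#   refined_mask = mask_to_mask_sam_predict(coarse_mask, sam_predictor, image, text=None, device=device)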
|
|
|
def mask_to_points_sam_predict(mask, sam_predictor, image, text, device):
    # PIL's Image.size is (width, height).
    W, H = image.size

    # Upsample the low-resolution mask to image resolution.
    mask = F.interpolate(mask.view(1, 1, mask.shape[-2], mask.shape[-1]), size=(H, W), mode='bilinear').view(H, W)
    mask_indices = torch.nonzero(mask)

    # Sample random foreground pixels as positive point prompts.
    # nonzero() yields (y, x) pairs while SAM expects (x, y), hence the flip.
    n_points = 2
    point_coords = mask_indices[torch.randperm(mask_indices.shape[0])[:n_points]].flip(-1).float().cpu().numpy()
    point_labels = np.ones((n_points,), dtype=np.float32)

    with torch.inference_mode(), torch.autocast("cuda", dtype=torch.bfloat16):
        sam_predictor.set_image(image)
        masks, scores, logits = sam_predictor.predict(
            point_coords=point_coords,
            point_labels=point_labels,
            multimask_output=False,
        )

    subject_mask = torch.tensor(masks[0], device=device)

    return subject_mask
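
# Usage sketch: prompt SAM with two randomly sampled foreground pixels.
# The sampling makes results nondeterministic; seed torch if you need
# reproducibility:
#
#   torch.manual_seed(0)
#   refined_mask = mask_to_points_sam_predict(coarse_mask, sam_predictor, image, text=None, device=device)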
|
|
|
def attention_to_points_sam_predict(subject_attention, subject_mask, sam_predictor, image, text, device):
    # PIL's Image.size is (width, height).
    W, H = image.size

    # Upsample the attention heatmap and the mask to image resolution.
    subject_attention = F.interpolate(subject_attention.view(1, 1, subject_attention.shape[-2], subject_attention.shape[-1]), size=(H, W), mode='bilinear').view(H, W)
    subject_mask = F.interpolate(subject_mask.view(1, 1, subject_mask.shape[-2], subject_mask.shape[-1]), size=(H, W), mode='bilinear').view(H, W)

    # Bounding box of the subject mask; its larger side sets the minimum
    # spacing between selected points.
    subject_mask_indices = torch.nonzero(subject_mask)
    top_left = subject_mask_indices.min(dim=0)[0]
    bottom_right = subject_mask_indices.max(dim=0)[0]
    box_width = (bottom_right[1] - top_left[1]).item()
    box_height = (bottom_right[0] - top_left[0]).item()

    n_points = 3
    max_thr = 0.35
    max_attention = torch.max(subject_attention)
    min_distance = max(box_width, box_height) // (n_points + 1)

    # Greedy peak picking with non-maximum suppression: repeatedly take
    # the strongest remaining attention peak, then zero out a square
    # neighborhood around it so the next point lands elsewhere.
    selected_points = []
    remaining_attention = subject_attention.clone()

    for _ in range(n_points):
        # Stop early once the strongest remaining peak is too weak.
        if remaining_attention.max() < max_thr * max_attention:
            break

        point = torch.argmax(remaining_attention)
        y, x = torch.unravel_index(point, remaining_attention.shape)
        y, x = y.item(), x.item()

        # SAM expects points as (x, y).
        selected_points.append((x, y))

        # Suppress the neighborhood of the selected peak.
        y_min = max(0, y - min_distance)
        y_max = min(H, y + min_distance + 1)
        x_min = max(0, x - min_distance)
        x_max = min(W, x + min_distance + 1)
        remaining_attention[y_min:y_max, x_min:x_max] = 0

    point_coords = np.array(selected_points)
    point_labels = np.ones(point_coords.shape[0], dtype=int)

    with torch.inference_mode(), torch.autocast("cuda", dtype=torch.bfloat16):
        sam_predictor.set_image(image)
        masks, scores, logits = sam_predictor.predict(
            point_coords=point_coords,
            point_labels=point_labels,
            multimask_output=False,
        )

    subject_mask = torch.tensor(masks[0], device=device)

    return subject_mask, point_coords
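
# Usage sketch: turn up to three well-separated attention peaks into
# positive SAM point prompts (`attn` is a hypothetical heatmap; any 2D
# map aligned with `coarse_mask` works):
#
#   attn = torch.rand(64, 64, device=device)
#   refined_mask, points = attention_to_points_sam_predict(attn, coarse_mask, sam_predictor, image, text=None, device=device)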
|
|
|
def sam_refine_step(mask, sam_predictor, image, device):
    # Bounding box of the current mask (already at image resolution).
    mask_indices = torch.nonzero(mask)
    top_left = mask_indices.min(dim=0)[0]
    bottom_right = mask_indices.max(dim=0)[0]

    # nonzero() yields (row, col) = (y, x); SAM boxes are (x1, y1, x2, y2).
    input_boxes = np.array([[top_left[1].item(), top_left[0].item(), bottom_right[1].item(), bottom_right[0].item()]])

    with torch.inference_mode(), torch.autocast("cuda", dtype=torch.bfloat16):
        sam_predictor.set_image(image)
        masks, scores, logits = sam_predictor.predict(
            point_coords=None,
            point_labels=None,
            box=input_boxes,
            multimask_output=True,
        )

    # Union of the candidate masks from multimask output.
    subject_mask = torch.tensor(np.max(masks, axis=0), device=device)

    return subject_mask, input_boxes[0]
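
# Usage sketch: sam_refine_step is the image-resolution counterpart of
# mask_to_box_sam_predict (no upsampling step), so it can be applied
# repeatedly to tighten a mask:
#
#   refined_mask, box = sam_refine_step(mask, sam_predictor, image, device)
#   refined_mask, box = sam_refine_step(refined_mask, sam_predictor, image, device)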