# addit/addit_blending_utils.py
# Copyright (C) 2025 NVIDIA Corporation. All rights reserved.
#
# This work is licensed under the LICENSE file
# located at the root directory.
import torch
import numpy as np
import torch.nn.functional as F
from skimage import filters
import matplotlib.pyplot as plt
from scipy.ndimage import maximum_filter, label, find_objects


def dilate_mask(latents_mask, k, latents_dtype):
    # Reshape the mask to 2D (64x64)
    mask_2d = latents_mask.view(64, 64)
    # Create a square kernel for dilation
    kernel = torch.ones(2*k+1, 2*k+1, device=mask_2d.device, dtype=mask_2d.dtype)
    # Add two dimensions to make it compatible with conv2d
    mask_4d = mask_2d.unsqueeze(0).unsqueeze(0)
    # Perform dilation using conv2d
    dilated_mask = F.conv2d(mask_4d, kernel.unsqueeze(0).unsqueeze(0), padding=k)
    # Threshold the result to get a binary mask
    dilated_mask = (dilated_mask > 0).to(mask_2d.dtype)
    # Reshape back to the original shape and convert to the desired dtype
    dilated_mask = dilated_mask.view(4096, 1).to(latents_dtype)
    return dilated_mask
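
# Example (hypothetical usage): dilate_mask works on a flattened 64x64 latent mask stored
# as a (4096, 1) tensor, e.g.
#   latents_mask = (torch.rand(4096, 1) > 0.9).float()
#   dilated = dilate_mask(latents_mask, k=2, latents_dtype=torch.float16)
#   # dilated has shape (4096, 1); each active pixel is grown by k latent pixels per side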


def clipseg_predict(model, processor, image, text, device):
    inputs = processor(text=text, images=image, return_tensors="pt")
    inputs = {k: v.to(device) for k, v in inputs.items()}
    with torch.no_grad():
        outputs = model(**inputs)
        preds = outputs.logits.unsqueeze(1)
        preds = torch.sigmoid(preds)
    # Binarize the predicted heatmap with an Otsu threshold
    otsu_thr = filters.threshold_otsu(preds.cpu().numpy())
    subject_mask = (preds > otsu_thr).float()
    return subject_mask
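
# Example (hypothetical setup): clipseg_predict expects a Hugging Face CLIPSeg model and
# processor, e.g.
#   from transformers import CLIPSegProcessor, CLIPSegForImageSegmentation
#   clipseg_processor = CLIPSegProcessor.from_pretrained("CIDAS/clipseg-rd64-refined")
#   clipseg_model = CLIPSegForImageSegmentation.from_pretrained("CIDAS/clipseg-rd64-refined").to(device)
#   subject_mask = clipseg_predict(clipseg_model, clipseg_processor, image, "a cat", device)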


def grounding_sam_predict(model, processor, sam_predictor, image, text, device):
    inputs = processor(images=image, text=text, return_tensors="pt").to(device)
    with torch.no_grad():
        outputs = model(**inputs)

    results = processor.post_process_grounded_object_detection(
        outputs,
        inputs.input_ids,
        box_threshold=0.4,
        text_threshold=0.3,
        target_sizes=[image.size[::-1]]
    )

    input_boxes = results[0]["boxes"].cpu().numpy()

    # If no box was detected, fall back to a full (all-ones) mask
    if input_boxes.shape[0] == 0:
        return torch.ones((64, 64), device=device)

    with torch.inference_mode(), torch.autocast("cuda", dtype=torch.bfloat16):
        sam_predictor.set_image(image)
        masks, scores, logits = sam_predictor.predict(
            point_coords=None,
            point_labels=None,
            box=input_boxes,
            multimask_output=False,
        )

    subject_mask = torch.tensor(masks[0], device=device)
    return subject_mask
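
# Example (hypothetical setup): grounding_sam_predict expects a Hugging Face Grounding DINO
# model/processor and a SAM-style image predictor exposing set_image()/predict(), e.g.
#   from transformers import AutoProcessor, AutoModelForZeroShotObjectDetection
#   gd_processor = AutoProcessor.from_pretrained("IDEA-Research/grounding-dino-tiny")
#   gd_model = AutoModelForZeroShotObjectDetection.from_pretrained("IDEA-Research/grounding-dino-tiny").to(device)
#   subject_mask = grounding_sam_predict(gd_model, gd_processor, sam_predictor, image, "a cat.", device)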


def mask_to_box_sam_predict(mask, sam_predictor, image, text, device):
    # PIL's Image.size is (width, height)
    W, H = image.size

    # Resize the low-resolution mask to the image size
    mask = F.interpolate(mask.view(1, 1, mask.shape[-2], mask.shape[-1]), size=(H, W), mode='bilinear').view(H, W)

    # Tight bounding box around the non-zero mask region
    mask_indices = torch.nonzero(mask)
    top_left = mask_indices.min(dim=0)[0]
    bottom_right = mask_indices.max(dim=0)[0]

    # numpy box of shape [1, 4] in (x_min, y_min, x_max, y_max) order
    input_boxes = np.array([[top_left[1].item(), top_left[0].item(), bottom_right[1].item(), bottom_right[0].item()]])

    with torch.inference_mode(), torch.autocast("cuda", dtype=torch.bfloat16):
        sam_predictor.set_image(image)
        masks, scores, logits = sam_predictor.predict(
            point_coords=None,
            point_labels=None,
            box=input_boxes,
            multimask_output=True,
        )

    # subject_mask = torch.tensor(masks[0], device=device)
    # Take the union of the candidate masks returned by SAM
    subject_mask = torch.tensor(np.max(masks, axis=0), device=device)
    return subject_mask, input_boxes[0]


def mask_to_mask_sam_predict(mask, sam_predictor, image, text, device):
    # SAM expects its mask prompt at a fixed 256x256 resolution
    H, W = (256, 256)
    # Resize the low-resolution mask to the mask-prompt size
    mask = F.interpolate(mask.view(1, 1, mask.shape[-2], mask.shape[-1]), size=(H, W), mode='bilinear').view(1, H, W)
    mask_input = mask.float().cpu().numpy()

    with torch.inference_mode(), torch.autocast("cuda", dtype=torch.bfloat16):
        sam_predictor.set_image(image)
        masks, scores, logits = sam_predictor.predict(
            point_coords=None,
            point_labels=None,
            mask_input=mask_input,
            multimask_output=False,
        )

    subject_mask = torch.tensor(masks[0], device=device)
    return subject_mask


def mask_to_points_sam_predict(mask, sam_predictor, image, text, device):
    # PIL's Image.size is (width, height)
    W, H = image.size

    # Resize the low-resolution mask to the image size
    mask = F.interpolate(mask.view(1, 1, mask.shape[-2], mask.shape[-1]), size=(H, W), mode='bilinear').view(H, W)
    mask_indices = torch.nonzero(mask)

    # Randomly sample n_points foreground points from the mask.
    # torch.nonzero returns (y, x) indices, while SAM expects (x, y) coordinates, so flip the last axis.
    n_points = 2
    point_coords = mask_indices[torch.randperm(mask_indices.shape[0])[:n_points]].flip(-1).float().cpu().numpy()
    point_labels = torch.ones((n_points,)).float().cpu().numpy()

    with torch.inference_mode(), torch.autocast("cuda", dtype=torch.bfloat16):
        sam_predictor.set_image(image)
        masks, scores, logits = sam_predictor.predict(
            point_coords=point_coords,
            point_labels=point_labels,
            multimask_output=False,
        )

    subject_mask = torch.tensor(masks[0], device=device)
    return subject_mask


def attention_to_points_sam_predict(subject_attention, subject_mask, sam_predictor, image, text, device):
    # PIL's Image.size is (width, height)
    W, H = image.size

    # Resize the attention map and the mask to the image size
    subject_attention = F.interpolate(subject_attention.view(1, 1, subject_attention.shape[-2], subject_attention.shape[-1]), size=(H, W), mode='bilinear').view(H, W)
    subject_mask = F.interpolate(subject_mask.view(1, 1, subject_mask.shape[-2], subject_mask.shape[-1]), size=(H, W), mode='bilinear').view(H, W)

    # Get the mask bounding box
    subject_mask_indices = torch.nonzero(subject_mask)
    top_left = subject_mask_indices.min(dim=0)[0]
    bottom_right = subject_mask_indices.max(dim=0)[0]
    box_width = bottom_right[1] - top_left[1]
    box_height = bottom_right[0] - top_left[0]

    # Define the number of points and the minimum distance between points
    n_points = 3
    max_thr = 0.35
    max_attention = torch.max(subject_attention)
    min_distance = max(box_width, box_height) // (n_points + 1)  # Adjust this value to control spread
    # min_distance = max(min_distance, 75)

    # Initialize list to store selected points
    selected_points = []

    # Work on a copy of the attention map
    remaining_attention = subject_attention.clone()

    for _ in range(n_points):
        # Stop once the remaining attention is weak relative to the global maximum
        if remaining_attention.max() < max_thr * max_attention:
            break

        # Find the highest attention point
        point = torch.argmax(remaining_attention)
        y, x = torch.unravel_index(point, remaining_attention.shape)
        y, x = y.item(), x.item()

        # Add the point to our list in (x, y) order, as expected by SAM
        selected_points.append((x, y))

        # Zero out the area around the selected point
        y_min = max(0, y - min_distance)
        y_max = min(H, y + min_distance + 1)
        x_min = max(0, x - min_distance)
        x_max = min(W, x + min_distance + 1)
        remaining_attention[y_min:y_max, x_min:x_max] = 0

    # Convert selected points to numpy coordinates with foreground labels
    point_coords = np.array(selected_points)
    point_labels = np.ones(point_coords.shape[0], dtype=int)

    with torch.inference_mode(), torch.autocast("cuda", dtype=torch.bfloat16):
        sam_predictor.set_image(image)
        masks, scores, logits = sam_predictor.predict(
            point_coords=point_coords,
            point_labels=point_labels,
            multimask_output=False,
        )

    subject_mask = torch.tensor(masks[0], device=device)
    return subject_mask, point_coords


def sam_refine_step(mask, sam_predictor, image, device):
    # Tight bounding box around the non-zero mask region
    mask_indices = torch.nonzero(mask)
    top_left = mask_indices.min(dim=0)[0]
    bottom_right = mask_indices.max(dim=0)[0]

    # numpy box of shape [1, 4] in (x_min, y_min, x_max, y_max) order
    input_boxes = np.array([[top_left[1].item(), top_left[0].item(), bottom_right[1].item(), bottom_right[0].item()]])

    with torch.inference_mode(), torch.autocast("cuda", dtype=torch.bfloat16):
        sam_predictor.set_image(image)
        masks, scores, logits = sam_predictor.predict(
            point_coords=None,
            point_labels=None,
            box=input_boxes,
            multimask_output=True,
        )

    # subject_mask = torch.tensor(masks[0], device=device)
    # Take the union of the candidate masks returned by SAM
    subject_mask = torch.tensor(np.max(masks, axis=0), device=device)
    return subject_mask, input_boxes[0]
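

# Minimal smoke-test sketch (assumes only torch is needed; the SAM / CLIPSeg helpers above
# require real model objects and are not exercised here).
if __name__ == "__main__":
    latents_mask = torch.zeros(4096, 1)
    latents_mask[32 * 64 + 32] = 1.0  # single active latent pixel near the center
    dilated = dilate_mask(latents_mask, k=2, latents_dtype=torch.float32)
    print("active pixels before:", int(latents_mask.sum()), "after dilation:", int(dilated.sum()))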