import numpy as np
import gradio as gr
import cv2
from copy import deepcopy
import torch
from torchvision import transforms
from PIL import Image

from sam.efficient_sam.build_efficient_sam import build_efficient_sam_vits
from src.utils.utils import resize_numpy_image

# EfficientSAM (ViT-S) is used purely for inference in this module.
sam = build_efficient_sam_vits()
sam.eval()
def show_point_or_box(image, global_points):
    """Draw the current selection on `image`: a dot after the first click,
    a rectangle once both corner points are available."""
    if len(global_points) == 1:
        image = cv2.circle(image, tuple(map(int, global_points[0])), 10, (0, 0, 255), -1)
    elif len(global_points) == 2:
        p1, p2 = global_points
        image = cv2.rectangle(
            image,
            (int(p1[0]), int(p1[1])),
            (int(p2[0]), int(p2[1])),
            (0, 0, 255),
            2,
        )
    return image

def segment_with_points(
    image,
    original_image,
    global_points,
    global_point_label,
    evt: gr.SelectData,
    img_direction,
    save_dir="./tmp",  # currently unused
):
    """Collect two clicks as a box prompt and segment with EfficientSAM.

    Labels 2 and 3 are EfficientSAM's point-prompt codes for the top-left and
    bottom-right corners of a box.
    """
    if original_image is None:
        original_image = image
    else:
        image = original_image
    if img_direction is None:
        img_direction = original_image
    x, y = evt.index[0], evt.index[1]

    if len(global_points) == 0:
        # First click: store the point and show it.
        global_points.append([x, y])
        global_point_label.append(2)
        image_with_point = show_point_or_box(image.copy(), global_points)
        return image_with_point, original_image, None, global_points, global_point_label
    elif len(global_points) == 1:
        # Second click: normalize the corners to (top-left, bottom-right).
        global_points.append([x, y])
        global_point_label.append(3)
        x1, y1 = global_points[0]
        x2, y2 = global_points[1]
        global_points[0] = [min(x1, x2), min(y1, y2)]
        global_points[1] = [max(x1, x2), max(y1, y2)]
        image_with_point = show_point_or_box(image.copy(), global_points)

        # Run EfficientSAM on the box prompt; no gradients are needed here.
        input_point = np.array(global_points)
        input_label = np.array(global_point_label)
        pts_sampled = torch.reshape(torch.tensor(input_point), [1, 1, -1, 2])
        pts_labels = torch.reshape(torch.tensor(input_label), [1, 1, -1])
        img_tensor = transforms.ToTensor()(image)

        with torch.no_grad():
            predicted_logits, predicted_iou = sam(
                img_tensor[None, ...],
                pts_sampled,
                pts_labels,
            )
        # Threshold the first predicted mask at logit 0.
        mask = torch.ge(predicted_logits[0, 0, 0, :, :], 0).float().cpu().numpy()
        mask_uint8 = (mask * 255.).astype(np.uint8)
        return image_with_point, original_image, mask_uint8, global_points, global_point_label
    else:
        # Third click: start a new selection.
        global_points = [[x, y]]
        global_point_label = [2]
        image_with_point = show_point_or_box(image.copy(), global_points)
        return image_with_point, original_image, None, global_points, global_point_label

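# Sketch of how `segment_with_points` is typically wired into the Gradio UI.
# The component and state names below are illustrative, not taken from this
# repo's app script; Gradio fills in the `evt: gr.SelectData` argument itself,
# so it does not appear in `inputs`:
#
#   input_img = gr.Image(label="Input")
#   input_img.select(
#       segment_with_points,
#       inputs=[input_img, original_image, global_points, global_point_label, img_direction],
#       outputs=[input_img, original_image, mask_output, global_points, global_point_label],
#   )
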
def segment_with_points_paste(
    image,
    original_image,
    global_points,
    global_point_label,
    image_b,
    evt: gr.SelectData,
    dx,
    dy,
    resize_scale,
):
    """Same two-click box segmentation as `segment_with_points`, but also
    paste the segmented region onto `image_b` with the given offset and scale."""
    if original_image is None:
        original_image = image
    else:
        image = original_image
    x, y = evt.index[0], evt.index[1]

    if len(global_points) == 0:
        # First click: store the point and show it.
        global_points.append([x, y])
        global_point_label.append(2)
        image_with_point = show_point_or_box(image.copy(), global_points)
        return image_with_point, original_image, None, global_points, global_point_label, None
    elif len(global_points) == 1:
        # Second click: normalize the corners to (top-left, bottom-right).
        global_points.append([x, y])
        global_point_label.append(3)
        x1, y1 = global_points[0]
        x2, y2 = global_points[1]
        global_points[0] = [min(x1, x2), min(y1, y2)]
        global_points[1] = [max(x1, x2), max(y1, y2)]
        image_with_point = show_point_or_box(image.copy(), global_points)

        # Run EfficientSAM on the box prompt; no gradients are needed here.
        input_point = np.array(global_points)
        input_label = np.array(global_point_label)
        pts_sampled = torch.reshape(torch.tensor(input_point), [1, 1, -1, 2])
        pts_labels = torch.reshape(torch.tensor(input_label), [1, 1, -1])
        img_tensor = transforms.ToTensor()(image)

        with torch.no_grad():
            predicted_logits, predicted_iou = sam(
                img_tensor[None, ...],
                pts_sampled,
                pts_labels,
            )
        mask = torch.ge(predicted_logits[0, 0, 0, :, :], 0).float().cpu().numpy()
        mask_uint8 = (mask * 255.).astype(np.uint8)

        pasted = paste_with_mask_and_offset(image, image_b, mask_uint8, dx, dy, resize_scale)
        return image_with_point, original_image, pasted, global_points, global_point_label, mask_uint8
    else:
        # Third click: start a new selection.
        global_points = [[x, y]]
        global_point_label = [2]
        image_with_point = show_point_or_box(image.copy(), global_points)
        return image_with_point, original_image, None, global_points, global_point_label, None

def paste_with_mask_and_offset(image_a, image_b, mask, x_offset=0, y_offset=0, delta=1):
    """Cut the masked region out of `image_a`, rescale it by `delta`, and paste
    it onto a half-transparent copy of `image_b` at the given pixel offset.
    Returns an RGBA PIL image, or None if the mask is empty or pasting fails."""
    try:
        # Bounding box and center of the masked region.
        numpy_mask = np.array(mask)
        y_coords, x_coords = np.nonzero(numpy_mask)
        x_min, x_max = x_coords.min(), x_coords.max()
        y_min, y_max = y_coords.min(), y_coords.max()
        target_center_x = int((x_min + x_max) / 2)
        target_center_y = int((y_min + y_max) / 2)

        image_a = Image.fromarray(image_a)
        image_b = Image.fromarray(image_b)
        mask = Image.fromarray(mask)

        if image_a.size != mask.size:
            mask = mask.resize(image_a.size)

        # Keep only the masked pixels of image_a; everything else is transparent.
        cropped_image = Image.composite(image_a, Image.new('RGBA', image_a.size, (0, 0, 0, 0)), mask)

        # Shift the offset so that scaling by `delta` pivots around the mask center.
        x_b = int(target_center_x * (image_b.width / cropped_image.width))
        y_b = int(target_center_y * (image_b.height / cropped_image.height))
        x_offset = x_offset - int((delta - 1) * x_b)
        y_offset = y_offset - int((delta - 1) * y_b)

        cropped_image = cropped_image.resize(image_b.size)
        new_size = (int(cropped_image.width * delta), int(cropped_image.height * delta))
        cropped_image = cropped_image.resize(new_size)

        # Compose: half-transparent background, then the cut-out on top.
        image_b.putalpha(128)
        result_image = Image.new('RGBA', image_b.size, (0, 0, 0, 0))
        result_image.paste(image_b, (0, 0))
        result_image.paste(cropped_image, (x_offset, y_offset), mask=cropped_image)

        return result_image
    except Exception:
        return None

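# Minimal standalone sketch of `paste_with_mask_and_offset` on synthetic
# arrays (names and values here are illustrative only):
#
#   a = np.zeros((64, 64, 3), dtype=np.uint8)
#   a[16:48, 16:48] = 255                      # white square to cut out
#   b = np.full((64, 64, 3), 127, dtype=np.uint8)
#   m = np.zeros((64, 64), dtype=np.uint8)
#   m[16:48, 16:48] = 255                      # mask selecting the square
#   out = paste_with_mask_and_offset(a, b, m, x_offset=5, y_offset=5, delta=1)
#   # `out` is an RGBA PIL image: the masked square pasted 5 px down/right
#   # onto a half-transparent copy of `b`.
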
def upload_image_move(img, original_image):
    # Keep the stored original if one exists; otherwise adopt the upload.
    if original_image is not None:
        return original_image
    else:
        return img


def fun_clear(*args):
    """Reset a batch of Gradio states: lists become empty lists, everything
    else becomes None, e.g. fun_clear([1, 2], None, 5) -> ([], None, None)."""
    result = []
    for arg in args:
        if isinstance(arg, list):
            result.append([])
        else:
            result.append(None)
    return tuple(result)

def clear_points(img):
    # `img` is a dict with "image" and "mask" entries (Gradio sketch-tool input).
    image, mask = img["image"], np.float32(img["mask"][:, :, 0]) / 255.
    if mask.sum() > 0:
        mask = np.uint8(mask > 0)
        masked_img = mask_image(image, 1 - mask, color=[0, 0, 0], alpha=0.3)
    else:
        masked_img = image.copy()
    # Reset the point list but keep the masked preview.
    return [], masked_img

def get_point(img, sel_pix, evt: gr.SelectData):
    """Record a clicked point and redraw all accumulated drag pairs: one dot
    per click (alternating colors) and an arrow from each source to its target."""
    sel_pix.append(evt.index)
    points = []
    for idx, point in enumerate(sel_pix):
        if idx % 2 == 0:
            cv2.circle(img, tuple(point), 10, (0, 0, 255), -1)
        else:
            cv2.circle(img, tuple(point), 10, (255, 0, 0), -1)
        points.append(tuple(point))
        if len(points) == 2:
            cv2.arrowedLine(img, points[0], points[1], (255, 255, 255), 4, tipLength=0.5)
            points = []
    return img if isinstance(img, np.ndarray) else np.array(img)

def calculate_translation_percentage(ori_shape, selected_points):
    """Return the drag offset between two points as fractions of the image
    width (dx) and height (dy); `ori_shape` is a (H, W, ...) numpy shape."""
    dx = selected_points[1][0] - selected_points[0][0]
    dy = selected_points[1][1] - selected_points[0][1]
    dx_percentage = dx / ori_shape[1]
    dy_percentage = dy / ori_shape[0]
    return dx_percentage, dy_percentage

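# Worked example: for a 100x200 image (H=100, W=200) and clicks at (10, 40)
# then (110, 20), the offsets are dx = (110 - 10) / 200 = 0.5 and
# dy = (20 - 40) / 100 = -0.2, i.e. half the width rightward and a fifth of
# the height upward.
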
def get_point_move(original_image, img, sel_pix, evt: gr.SelectData):
    """Like `get_point`, but redraws on a fresh copy of the original image and
    also returns the drag offset as a fraction of the image size."""
    if original_image is not None:
        img = original_image.copy()
    else:
        original_image = img.copy()
    # Keep at most one source/target pair; a third click starts a new pair.
    if len(sel_pix) < 2:
        sel_pix.append(evt.index)
    else:
        sel_pix = [evt.index]
    points = []
    dx, dy = 0, 0
    for idx, point in enumerate(sel_pix):
        if idx % 2 == 0:
            cv2.circle(img, tuple(point), 10, (0, 0, 255), -1)
        else:
            cv2.circle(img, tuple(point), 10, (255, 0, 0), -1)
        points.append(tuple(point))
        if len(points) == 2:
            cv2.arrowedLine(img, points[0], points[1], (255, 255, 255), 4, tipLength=0.5)
            dx, dy = calculate_translation_percentage(original_image.shape, sel_pix)
            points = []
    img = np.array(img)
    return img, original_image, sel_pix, dx, dy

def store_img(img):
    """Split a sketch-tool input into the raw image, a preview with the
    unmasked area dimmed, and the binary mask."""
    image, mask = img["image"], np.float32(img["mask"][:, :, 0]) / 255.
    if mask.sum() > 0:
        mask = np.uint8(mask > 0)
        masked_img = mask_image(image, 1 - mask, color=[0, 0, 0], alpha=0.3)
    else:
        masked_img = image.copy()
    return image, masked_img, mask


def store_img_move(img, mask=None):
    """Variant of `store_img` that reuses a precomputed mask when given and
    returns the mask scaled to uint8 in [0, 255]."""
    if mask is not None:
        image = img["image"]
        return image, None, mask
    image, mask = img["image"], np.float32(img["mask"][:, :, 0]) / 255.
    if mask.sum() > 0:
        mask = np.uint8(mask > 0)
        masked_img = mask_image(image, 1 - mask, color=[0, 0, 0], alpha=0.3)
    else:
        masked_img = image.copy()
    return image, masked_img, (mask * 255.).astype(np.uint8)

def mask_image(image, mask, color=[255, 0, 0], alpha=0.5, max_resolution=None):
    """Overlay a mask on an image for visualization.

    Args:
        image (H, W, 3) or (H, W): input image
        mask (H, W): binary mask to overlay
        color: color applied to masked pixels before blending
        alpha: opacity of the overlay
        max_resolution: if set, resize the image so its area is at most
            max_resolution**2 before overlaying
    """
    if max_resolution is not None:
        image, _ = resize_numpy_image(image, max_resolution * max_resolution)
        mask = cv2.resize(mask, (image.shape[1], image.shape[0]), interpolation=cv2.INTER_NEAREST)

    out = deepcopy(image)
    img = deepcopy(image)
    img[mask == 1] = color
    out = cv2.addWeighted(img, alpha, out, 1 - alpha, 0, out)
    return out
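

# Example (synthetic data, illustrative only):
#
#   img = np.zeros((100, 100, 3), dtype=np.uint8)
#   m = np.zeros((100, 100), dtype=np.uint8)
#   m[25:75, 25:75] = 1
#   vis = mask_image(img, m, color=[255, 0, 0], alpha=0.5)
#   # masked pixels are blended halfway toward red; the rest is unchanged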