# Spaces: Running on L4
import base64
import logging
from io import BytesIO

import matplotlib.pyplot as plt
import numpy as np
import torch
def fig_to_base64(fig):
    """Render ``fig`` to PNG, close it, and return the PNG as a base64 string.

    The figure is closed even when ``savefig`` raises. Otherwise pyplot keeps
    a reference to every failed figure for the life of the process.

    Args:
        fig: a matplotlib Figure.

    Returns:
        The PNG bytes encoded as an ASCII base64 ``str``.
    """
    with BytesIO() as buf:
        try:
            fig.savefig(buf, format='png', bbox_inches='tight')
        finally:
            plt.close(fig)
        # getvalue() returns the whole buffer regardless of the stream
        # position, so no seek(0) is needed.
        return base64.b64encode(buf.getvalue()).decode('ascii')
def show_mask(mask, ax, random_color=False):
    """Overlay ``mask`` on ``ax`` as a translucent RGBA image.

    Args:
        mask: array whose last two dimensions are (height, width). It is
            reshaped to (height, width, 1), so any leading dimensions must
            have size 1. Its values scale the overlay colour.
        ax: object with an ``imshow`` method (e.g. a matplotlib Axes).
        random_color: when True, use a random RGB colour from
            ``np.random``; otherwise use a fixed blue. Alpha is 0.6 either way.
    """
    rgba = (
        np.append(np.random.random(3), 0.6)
        if random_color
        else np.array([30 / 255, 144 / 255, 255 / 255, 0.6])
    )
    height, width = mask.shape[-2:]
    # Broadcast the per-pixel mask value against the 4 colour channels.
    overlay = mask.reshape(height, width, 1) * rgba[np.newaxis, np.newaxis, :]
    ax.imshow(overlay)
def show_box(box, ax):
    """Outline a corner-format box on ``ax`` as an unfilled green rectangle.

    Args:
        box: indexable whose first four items are (x0, y0, x1, y1). Width is
            x1 - x0 and height is y1 - y0; any extra items are ignored.
        ax: matplotlib Axes to draw on.
    """
    x0, y0, x1, y1 = box[0], box[1], box[2], box[3]
    outline = plt.Rectangle(
        (x0, y0),
        x1 - x0,
        y1 - y0,
        edgecolor='green',
        facecolor=(0, 0, 0, 0),  # fully transparent fill
        lw=2,
    )
    ax.add_patch(outline)
def show_points(coords, labels, ax, marker_size=375):
    """Scatter labelled prompt points on ``ax`` as star markers.

    Points labelled 1 are drawn green, points labelled 0 are drawn red, and
    any other label is skipped.

    Args:
        coords: point coordinates (array, tensor or nested list) whose last
            axis holds (x, y). Leading axes are flattened, so a single point
            ``[x, y]`` and batched layouts such as ``[[[x, y], ...]]`` work.
        labels: one label per point (array, tensor or list), flattened to 1-D.
        ax: object with a matplotlib-style ``scatter`` method.
        marker_size: marker area passed to ``scatter`` as ``s``.
    """
    # Coerce to ndarrays first. With a plain-list ``labels``, ``labels == 1``
    # would be a scalar bool and the boolean indexing below would misbehave.
    coords = np.asarray(coords)
    coords = coords.reshape(-1, coords.shape[-1])
    labels = np.asarray(labels).reshape(-1)

    pos_points = coords[labels == 1]
    neg_points = coords[labels == 0]
    ax.scatter(pos_points[:, 0], pos_points[:, 1], color='green', marker='*',
               s=marker_size, edgecolor='white', linewidth=1.25)
    ax.scatter(neg_points[:, 0], neg_points[:, 1], color='red', marker='*',
               s=marker_size, edgecolor='white', linewidth=1.25)
def show_boxes_on_image_base64(raw_image, boxes):
    """Draw every box in ``boxes`` over ``raw_image`` and return a base64 PNG.

    Args:
        raw_image: image passed straight to ``ax.imshow``.
        boxes: iterable of (x0, y0, x1, y1) boxes, each drawn with ``show_box``.

    Returns:
        The figure encoded as a base64 PNG string (the figure is closed).
    """
    figure, axis = plt.subplots(figsize=(10, 10))
    axis.imshow(raw_image)
    for single_box in boxes:
        show_box(single_box, axis)
    axis.axis('off')
    return fig_to_base64(figure)
def show_points_on_image_base64(raw_image, input_points, input_labels=None):
    """Draw prompt points over ``raw_image`` and return a base64 PNG.

    Args:
        raw_image: image passed straight to ``ax.imshow``.
        input_points: coordinates whose last axis holds (x, y). Leading axes
            are flattened, so a single point ``[x, y]`` and nested layouts
            such as ``[[[x, y], ...]]`` are accepted.
        input_labels: optional per-point labels, flattened to 1-D. Label 1 is
            drawn green, label 0 red. Defaults to all 1.

    Returns:
        The figure encoded as a base64 PNG string (the figure is closed).
    """
    # Normalise before opening a figure, so malformed input raises without
    # leaving an unclosed pyplot figure behind.
    points = np.atleast_2d(np.asarray(input_points))
    points = points.reshape(-1, points.shape[-1])
    if input_labels is None:
        labels = np.ones(len(points), dtype=int)
    else:
        labels = np.asarray(input_labels).reshape(-1)

    fig, ax = plt.subplots(figsize=(10, 10))
    ax.imshow(raw_image)
    show_points(points, labels, ax)
    ax.axis('off')
    return fig_to_base64(fig)
def show_points_and_boxes_on_image_base64(raw_image, boxes, input_points, input_labels=None):
    """Draw prompt points and boxes over ``raw_image`` and return a base64 PNG.

    Args:
        raw_image: image passed straight to ``ax.imshow``.
        boxes: iterable of (x0, y0, x1, y1) boxes, each drawn with ``show_box``.
        input_points: coordinates whose last axis holds (x, y). Leading axes
            are flattened, so a single point ``[x, y]`` and nested layouts
            such as ``[[[x, y], ...]]`` are accepted.
        input_labels: optional per-point labels, flattened to 1-D. Label 1 is
            drawn green, label 0 red. Defaults to all 1.

    Returns:
        The figure encoded as a base64 PNG string (the figure is closed).
    """
    # Normalise before opening a figure, so malformed input raises without
    # leaving an unclosed pyplot figure behind.
    points = np.atleast_2d(np.asarray(input_points))
    points = points.reshape(-1, points.shape[-1])
    if input_labels is None:
        labels = np.ones(len(points), dtype=int)
    else:
        labels = np.asarray(input_labels).reshape(-1)

    fig, ax = plt.subplots(figsize=(10, 10))
    ax.imshow(raw_image)
    show_points(points, labels, ax)
    for box in boxes:
        show_box(box, ax)
    ax.axis('off')
    return fig_to_base64(fig)
def show_masks_on_image_base64(raw_image, masks, scores):
    """Draw each mask on its own copy of the image, side by side, titled with its score.

    Args:
        raw_image: image converted once with ``np.array`` and shown in every panel.
        masks: torch tensor or numpy array whose last two axes are (H, W).
            All leading axes are flattened, so ``(1, K, H, W)``, ``(K, H, W)``
            and a single ``(H, W)`` mask all work.
        scores: torch tensor or array-like with one score per mask, any shape;
            flattened to 1-D.

    Returns:
        The figure encoded as a base64 PNG string (the figure is closed).

    Raises:
        ValueError: if there are no mask/score pairs to draw.
    """
    if isinstance(masks, torch.Tensor):
        masks = masks.cpu().detach().numpy()
    masks = np.asarray(masks)
    # Flatten leading axes explicitly. The old squeeze() collapsed a single
    # (1, 1, H, W) mask to (H, W) and turned a single score into a 0-d value.
    masks = masks.reshape(-1, *masks.shape[-2:])
    if isinstance(scores, torch.Tensor):
        scores = scores.reshape(-1)  # keep as tensor; float() handles any dtype/device
    else:
        scores = np.asarray(scores).reshape(-1)

    count = min(len(masks), len(scores))
    if count == 0:
        raise ValueError("show_masks_on_image_base64: no masks/scores to draw")
    logging.getLogger(__name__).debug("Number of predictions: %d", count)

    image = np.array(raw_image)  # loop-invariant: convert once
    # squeeze=False always yields a 2-D grid, so a single panel needs no special case.
    fig, axes = plt.subplots(1, count, figsize=(5 * count, 5), squeeze=False)
    for i, (ax, mask, score) in enumerate(zip(axes[0], masks, scores)):
        ax.imshow(image)
        show_mask(mask, ax)
        ax.title.set_text(f"Mask {i + 1}, Score: {float(score):.3f}")
        ax.axis("off")
    return fig_to_base64(fig)
def show_first_mask_on_image_base64(raw_image, masks, scores):
    """Render only the first mask over ``raw_image`` and return a base64 PNG.

    Args:
        raw_image: image converted with ``np.array`` before display.
        masks: array or tensor exposing ``ndim``. For 4-D input ``masks[0, 0]``
            is drawn, for 3-D input ``masks[0]``, otherwise ``masks`` itself.
            Torch tensors are moved to CPU and converted to numpy.
        scores: optional scores. When given, the first flattened value is
            shown as the title "Score: x.xxx"; ``None`` leaves the title empty.

    Returns:
        The figure encoded as a base64 PNG string (the figure is closed).
    """
    # Index tuple that peels off the leading axes for the supported ranks.
    leading_index = {4: (0, 0), 3: (0,)}.get(masks.ndim)
    mask = masks if leading_index is None else masks[leading_index]
    if isinstance(mask, torch.Tensor):
        mask = mask.cpu().detach().numpy()

    if isinstance(scores, torch.Tensor):
        first_score = scores.flatten()[0].item()
    elif scores is not None:
        first_score = float(np.array(scores).flatten()[0])
    else:
        first_score = None
    title = "" if first_score is None else f"Score: {first_score:.3f}"

    fig, ax = plt.subplots(figsize=(5, 5))
    ax.imshow(np.array(raw_image))
    show_mask(mask, ax)
    ax.set_title(title)
    ax.axis("off")
    return fig_to_base64(fig)
def show_all_annotations_on_image_base64(raw_image, masks=None, scores=None, boxes=None, input_points=None, input_labels=None, model_name=None):
    """Overlay the first mask, prompt points and boxes on one image; return a base64 PNG.

    Args:
        raw_image: image converted with ``np.array`` before display.
        masks: optional array/tensor exposing ``ndim``. For 4-D input
            ``masks[0, 0]`` is drawn, for 3-D ``masks[0]``, otherwise ``masks``.
        scores: optional. Only controls whether a title is drawn; the score
            values themselves are not displayed.
        boxes: optional iterable of (x0, y0, x1, y1) boxes.
        input_points: optional coordinates whose last axis holds (x, y);
            leading axes are flattened, so ``[x, y]`` and nested layouts work.
        input_labels: optional per-point labels, flattened to 1-D; defaults to all 1.
        model_name: title text, shown when ``scores`` is given.

    Returns:
        The figure encoded as a base64 PNG string (the figure is closed).
    """
    # Prepare all inputs before opening a figure, so malformed input raises
    # without leaving an unclosed pyplot figure behind.
    mask = None
    if masks is not None:
        if masks.ndim == 4:
            mask = masks[0, 0]
        elif masks.ndim == 3:
            mask = masks[0]
        else:
            mask = masks
        if isinstance(mask, torch.Tensor):
            mask = mask.cpu().detach().numpy()

    points = labels = None
    if input_points is not None:
        points = np.atleast_2d(np.asarray(input_points))
        points = points.reshape(-1, points.shape[-1])
        if input_labels is None:
            labels = np.ones(len(points), dtype=int)
        else:
            labels = np.asarray(input_labels).reshape(-1)

    fig, ax = plt.subplots(figsize=(10, 10))
    ax.imshow(np.array(raw_image))
    if mask is not None:
        show_mask(mask, ax)
    # The title used to include the score; it now shows only the model name.
    # Keep titling only when scores were passed, but never render "None".
    if scores is not None and model_name is not None:
        ax.set_title(f"{model_name}")
    if points is not None:
        show_points(points, labels, ax)
    if boxes is not None:
        for box in boxes:
            show_box(box, ax)
    ax.axis("off")
    return fig_to_base64(fig)