r""" Visualize model predictions """
import os

from PIL import Image
import numpy as np
import torchvision.transforms as transforms

from fewshot_data.common import utils


class Visualizer:

    @classmethod
    def initialize(cls, visualize):
        r""" Sets up overlay colors, image statistics, and the output directory """
        cls.visualize = visualize
        if not visualize:
            return

        # Support masks are drawn in blue; query ground-truth and predicted masks in red
        cls.colors = {'red': (255, 50, 50), 'blue': (102, 140, 255)}
        for key, value in cls.colors.items():
            cls.colors[key] = tuple(c / 255 for c in value)

        cls.mean_img = [0.5] * 3
        cls.std_img = [0.5] * 3
        cls.to_pil = transforms.ToPILImage()
        cls.vis_path = './vis/'
        os.makedirs(cls.vis_path, exist_ok=True)

    @classmethod
    def visualize_prediction_batch(cls, spt_img_b, spt_mask_b, qry_img_b, qry_mask_b, pred_mask_b, cls_id_b, batch_idx, iou_b=None):
        r""" Moves a batch of episodes to CPU and visualizes each sample in turn """
        spt_img_b = utils.to_cpu(spt_img_b)
        spt_mask_b = utils.to_cpu(spt_mask_b)
        qry_img_b = utils.to_cpu(qry_img_b)
        qry_mask_b = utils.to_cpu(qry_mask_b)
        pred_mask_b = utils.to_cpu(pred_mask_b)
        cls_id_b = utils.to_cpu(cls_id_b)

        for sample_idx, (spt_img, spt_mask, qry_img, qry_mask, pred_mask, cls_id) in \
                enumerate(zip(spt_img_b, spt_mask_b, qry_img_b, qry_mask_b, pred_mask_b, cls_id_b)):
            iou = iou_b[sample_idx] if iou_b is not None else None
            cls.visualize_prediction(spt_img, spt_mask, qry_img, qry_mask, pred_mask, cls_id, batch_idx, sample_idx, True, iou)

    @classmethod
    def to_numpy(cls, tensor, type):
        r""" Converts a torch tensor into a numpy array, un-normalizing image tensors first """
        if type == 'img':
            return np.array(cls.to_pil(cls.unnormalize(tensor))).astype(np.uint8)
        elif type == 'mask':
            return np.array(tensor).astype(np.uint8)
        else:
            raise Exception('Undefined tensor type: %s' % type)

    @classmethod
    def visualize_prediction(cls, spt_imgs, spt_masks, qry_img, qry_mask, pred_mask, cls_id, batch_idx, sample_idx, label, iou=None):
        r""" Overlays masks on the support and query images and saves the merged visualization """

        spt_color = cls.colors['blue']
        qry_color = cls.colors['red']
        pred_color = cls.colors['red']

        spt_imgs = [cls.to_numpy(spt_img, 'img') for spt_img in spt_imgs]
        spt_pils = [cls.to_pil(spt_img) for spt_img in spt_imgs]
        spt_masks = [cls.to_numpy(spt_mask, 'mask') for spt_mask in spt_masks]
        spt_masked_pils = [Image.fromarray(cls.apply_mask(spt_img, spt_mask, spt_color)) for spt_img, spt_mask in zip(spt_imgs, spt_masks)]

        qry_img = cls.to_numpy(qry_img, 'img')
        qry_pil = cls.to_pil(qry_img)
        qry_mask = cls.to_numpy(qry_mask, 'mask')
        pred_mask = cls.to_numpy(pred_mask, 'mask')
        pred_masked_pil = Image.fromarray(cls.apply_mask(qry_img.astype(np.uint8), pred_mask.astype(np.uint8), pred_color))
        qry_masked_pil = Image.fromarray(cls.apply_mask(qry_img.astype(np.uint8), qry_mask.astype(np.uint8), qry_color))

        merged_pil = cls.merge_image_pair(spt_masked_pils + [pred_masked_pil, qry_masked_pil])

        iou = iou.item() if iou is not None else 0.0
        merged_pil.save(cls.vis_path + '%d_%d_class-%d_iou-%.2f' % (batch_idx, sample_idx, cls_id, iou) + '.jpg')

    @classmethod
    def merge_image_pair(cls, pil_imgs):
        r""" Horizontally concatenates a list of PIL images and returns the merged PIL image """

        canvas_width = sum([pil.size[0] for pil in pil_imgs])
        canvas_height = max([pil.size[1] for pil in pil_imgs])
        canvas = Image.new('RGB', (canvas_width, canvas_height))

        xpos = 0
        for pil in pil_imgs:
            canvas.paste(pil, (xpos, 0))
            xpos += pil.size[0]

        return canvas

    @classmethod
    def apply_mask(cls, image, mask, color, alpha=0.5):
        r""" Blends the given RGB color into the image wherever the binary mask equals 1 """
        for c in range(3):
            image[:, :, c] = np.where(mask == 1,
                                      image[:, :, c] * (1 - alpha) + alpha * color[c] * 255,
                                      image[:, :, c])
        return image

    @classmethod
    def unnormalize(cls, img):
        r""" Reverses channel-wise normalization so the image can be rendered """
        img = img.clone()
        for im_channel, mean, std in zip(img, cls.mean_img, cls.std_img):
            im_channel.mul_(std).add_(mean)
        return img
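

# --------------------------------------------------------------------------
# Minimal usage sketch (illustrative only, not part of the original training
# code): feeds a fake 1-shot episode of random tensors through
# visualize_prediction and writes the merged overlay to ./vis/. The tensor
# shapes and the 128x128 resolution below are assumptions made for the sake
# of a runnable example.
# --------------------------------------------------------------------------
if __name__ == '__main__':
    import torch

    Visualizer.initialize(visualize=True)

    spt_imgs = torch.rand(1, 3, 128, 128)            # one support image, values in [0, 1]
    spt_masks = torch.randint(0, 2, (1, 128, 128))   # binary support mask
    qry_img = torch.rand(3, 128, 128)                # query image
    qry_mask = torch.randint(0, 2, (128, 128))       # ground-truth query mask
    pred_mask = torch.randint(0, 2, (128, 128))      # stand-in for a model prediction

    Visualizer.visualize_prediction(spt_imgs, spt_masks, qry_img, qry_mask, pred_mask,
                                    cls_id=0, batch_idx=0, sample_idx=0, label=True)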