File size: 4,372 Bytes
57746f1 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 |
r""" Visualize model predictions """
import os
from PIL import Image
import numpy as np
import torchvision.transforms as transforms
from fewshot_data.common import utils
class Visualizer:
    r""" Class-level helper that renders few-shot segmentation predictions to JPEG files.

    All state lives on the class and is created by :meth:`initialize`, which must be
    called with a truthy ``visualize`` before any of the drawing methods are used.
    """

    @classmethod
    def initialize(cls, visualize):
        r""" Configure the visualizer.

        Args:
            visualize: when falsy, only ``cls.visualize`` is recorded and no other
                state (colours, normalisation stats, output directory) is created.
        """
        cls.visualize = visualize
        if not visualize:
            return

        # Overlay colours stored as RGB fractions in [0, 1]; apply_mask scales them back by 255.
        rgb_colors = {'red': (255, 50, 50), 'blue': (102, 140, 255)}
        cls.colors = {name: tuple(c / 255 for c in rgb) for name, rgb in rgb_colors.items()}

        # unnormalize() computes x * std + mean with these values; presumably they must match
        # the Normalize transform of the data pipeline -- verify. ImageNet statistics would be
        # mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225].
        cls.mean_img = [0.5] * 3
        cls.std_img = [0.5] * 3

        cls.to_pil = transforms.ToPILImage()
        cls.vis_path = './vis/'
        # exist_ok avoids the exists()/makedirs() race when several processes start together.
        os.makedirs(cls.vis_path, exist_ok=True)

    @classmethod
    def visualize_prediction_batch(cls, spt_img_b, spt_mask_b, qry_img_b, qry_mask_b, pred_mask_b, cls_id_b, batch_idx, iou_b=None):
        r""" Save one visualisation per sample of a batch.

        Every batched argument is passed through ``utils.to_cpu`` and then iterated
        along its first dimension in lockstep; ``iou_b`` (optional) is indexed by
        sample position. Each sample is written by :meth:`visualize_prediction`.
        """
        spt_img_b = utils.to_cpu(spt_img_b)
        spt_mask_b = utils.to_cpu(spt_mask_b)
        qry_img_b = utils.to_cpu(qry_img_b)
        qry_mask_b = utils.to_cpu(qry_mask_b)
        pred_mask_b = utils.to_cpu(pred_mask_b)
        cls_id_b = utils.to_cpu(cls_id_b)

        samples = zip(spt_img_b, spt_mask_b, qry_img_b, qry_mask_b, pred_mask_b, cls_id_b)
        for sample_idx, (spt_img, spt_mask, qry_img, qry_mask, pred_mask, cls_id) in enumerate(samples):
            iou = iou_b[sample_idx] if iou_b is not None else None
            cls.visualize_prediction(spt_img, spt_mask, qry_img, qry_mask, pred_mask, cls_id,
                                     batch_idx, sample_idx, True, iou)

    @classmethod
    def to_numpy(cls, tensor, type):
        r""" Convert a tensor to a ``np.uint8`` array.

        Args:
            tensor: for ``'img'``, a normalised image tensor accepted by ``ToPILImage``
                (assumed (3, H, W) float -- TODO confirm); for ``'mask'``, anything
                ``np.array`` accepts.
            type: ``'img'`` or ``'mask'``.

        Returns:
            ``np.uint8`` array; images come back channel-last (H, W, 3) from PIL.

        Raises:
            Exception: for any other ``type``.
        """
        if type == 'img':
            # ToPILImage maps float tensors with mul(255).byte(), so values pushed just
            # outside [0, 1] (e.g. by resize interpolation) would wrap around; clamp first.
            img = cls.unnormalize(tensor).clamp(0, 1)
            return np.array(cls.to_pil(img)).astype(np.uint8)
        elif type == 'mask':
            return np.array(tensor).astype(np.uint8)
        else:
            raise Exception('Undefined tensor type: %s' % type)

    @classmethod
    def visualize_prediction(cls, spt_imgs, spt_masks, qry_img, qry_mask, pred_mask, cls_id, batch_idx, sample_idx, label, iou=None):
        r""" Render one sample as a horizontal strip and save it as JPEG under ``cls.vis_path``.

        Panels, left to right: each support image with its mask (blue), the query
        image with the predicted mask (red), the query image with its ground-truth
        mask (red).

        Args:
            spt_imgs, spt_masks: iterables of support images / masks (one per shot).
            qry_img, qry_mask, pred_mask: query image, ground-truth and predicted masks.
            cls_id, batch_idx, sample_idx: written into the file name.
            label: unused; kept for interface compatibility.
            iou: optional IoU (0-d tensor, numpy scalar or number); 0.0 is written when None.
        """
        spt_color = cls.colors['blue']
        qry_color = cls.colors['red']
        pred_color = cls.colors['red']

        spt_masked_pils = [
            Image.fromarray(cls.apply_mask(cls.to_numpy(spt_img, 'img'),
                                           cls.to_numpy(spt_mask, 'mask'),
                                           spt_color))
            for spt_img, spt_mask in zip(spt_imgs, spt_masks)
        ]

        qry_img = cls.to_numpy(qry_img, 'img')
        qry_mask = cls.to_numpy(qry_mask, 'mask')
        pred_mask = cls.to_numpy(pred_mask, 'mask')
        # apply_mask draws in place, so each overlay gets its own copy of the query image.
        pred_masked_pil = Image.fromarray(cls.apply_mask(qry_img.copy(), pred_mask, pred_color))
        qry_masked_pil = Image.fromarray(cls.apply_mask(qry_img.copy(), qry_mask, qry_color))

        merged_pil = cls.merge_image_pair(spt_masked_pils + [pred_masked_pil, qry_masked_pil])
        merged_pil.save(cls._sample_path(batch_idx, sample_idx, cls_id, iou))

    @classmethod
    def _sample_path(cls, batch_idx, sample_idx, cls_id, iou=None):
        r""" Build ``<vis_path>/<batch>_<sample>_class-<cls_id>_iou-<iou:.2f>.jpg``. """
        # float() accepts 0-d tensors, numpy scalars and plain numbers alike, and the
        # ``is not None`` test avoids relying on the truth value of a tensor.
        iou = float(iou) if iou is not None else 0.0
        filename = '%d_%d_class-%d_iou-%.2f' % (batch_idx, sample_idx, int(cls_id), iou) + '.jpg'
        return os.path.join(cls.vis_path, filename)

    @classmethod
    def merge_image_pair(cls, pil_imgs):
        r""" Paste PIL images side by side, left to right, onto one RGB canvas.

        The canvas is as wide as all images combined and as tall as the tallest one;
        shorter images are top-aligned and the area below them stays black.

        Args:
            pil_imgs: non-empty iterable of PIL images.

        Returns:
            A new ``PIL.Image`` in ``'RGB'`` mode.

        Raises:
            ValueError: if ``pil_imgs`` is empty.
        """
        # Materialise once: the sizes are read twice below, which would exhaust a generator.
        pil_imgs = list(pil_imgs)
        if not pil_imgs:
            raise ValueError('merge_image_pair needs at least one image')

        canvas_width = sum(pil.size[0] for pil in pil_imgs)
        canvas_height = max(pil.size[1] for pil in pil_imgs)
        canvas = Image.new('RGB', (canvas_width, canvas_height))

        xpos = 0
        for pil in pil_imgs:
            canvas.paste(pil, (xpos, 0))
            xpos += pil.size[0]
        return canvas

    @classmethod
    def apply_mask(cls, image, mask, color, alpha=0.5):
        r""" Alpha-blend ``color`` into ``image`` wherever ``mask == 1``, in place.

        Args:
            image: (H, W, >=3) array; channels 0..2 are blended and written back in the
                array's own dtype (so uint8 results are truncated, not rounded).
            mask: (H, W) array; only pixels exactly equal to 1 are coloured.
            color: RGB triple, presumably with components in [0, 1] (scaled by 255 here).
            alpha: weight of ``color`` in the blend.

        Returns:
            ``image`` itself, after modification.
        """
        selected = mask == 1  # loop-invariant: compute once for all channels
        for c in range(3):
            image[:, :, c] = np.where(selected,
                                      image[:, :, c] * (1 - alpha) + alpha * color[c] * 255,
                                      image[:, :, c])
        return image

    @classmethod
    def unnormalize(cls, img):
        r""" Return a copy of ``img`` with ``x * std + mean`` applied channel by channel.

        ``img`` is iterated along its first dimension, pairing each slice with the
        matching entry of ``cls.mean_img`` / ``cls.std_img``; the input is not modified.
        """
        img = img.clone()
        for im_channel, mean, std in zip(img, cls.mean_img, cls.std_img):
            im_channel.mul_(std).add_(mean)
        return img
|