File size: 1,966 Bytes
9e15541
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
import numpy as np
import torch

import matplotlib.pyplot as plt


def draw_bbox(im, size, color=(1.0, 0.0, 0.0)):
    """Draw a one-pixel square outline centred in every image of a batch, in place.

    The outline's top-left corner is at ``((h - size) // 2, (w - size) // 2)``.
    Its bottom and right edges lie ``size`` pixels further down/right, so the
    square covers ``size + 1`` pixels per side (all four corners included).

    Args:
        im: Tensor indexed as ``im[batch, channel, row, col]``; modified in place.
            Its channel count must match ``len(color)``.
        size: Offset between opposite edges of the square, in pixels. The edge
            at ``h2 + size`` / ``w2 + size`` must lie inside the image.
        color: Per-channel value written on the outline. Defaults to
            ``(1.0, 0.0, 0.0)`` (the original hard-coded marker).

    Returns:
        The same tensor object ``im``.
    """
    _, _, h, w = im.shape
    h2, w2 = (h - size) // 2, (w - size) // 2
    # Shape (c, 1): broadcasts along the length of each edge. It is created on
    # im's own device/dtype, avoiding the legacy FloatTensor constructor.
    marker = torch.tensor(color, dtype=im.dtype, device=im.device).view(-1, 1)
    # Horizontal edges span size + 1 columns so the bottom-right corner
    # (h2 + size, w2 + size) is drawn as well.
    im[:, :, h2, w2 : w2 + size + 1] = marker
    im[:, :, h2 + size, w2 : w2 + size + 1] = marker
    im[:, :, h2 : h2 + size, w2] = marker
    im[:, :, h2 : h2 + size, w2 + size] = marker
    return im


def plot_image_grid(
    images,
    rows,
    cols,
    directions=None,
    imsize=(2, 2),
    title=None,
    show=True,
    arrow_origin=(32, 32),
    arrow_scale=16,
):
    """Plot images in a borderless ``rows x cols`` grid, filled column by column.

    Args:
        images: Iterable of images accepted by ``Axes.imshow``. At most
            ``rows * cols`` items.
        rows: Number of grid rows.
        cols: Number of grid columns.
        directions: Optional sequence with one ``(dx, dy)`` pair per image. When
            given, a red arrow from ``arrow_origin`` along
            ``(dx, dy) * arrow_scale`` is drawn on each image, in data coordinates.
        imsize: ``(width, height)`` of one grid cell, in inches.
        title: Optional figure title.
        show: If True, call ``plt.show()`` before returning.
        arrow_origin: Arrow start point in data coordinates. The default
            ``(32, 32)`` presumably targets the centre of 64x64 images.
        arrow_scale: Multiplier applied to each direction vector.

    Returns:
        The created ``matplotlib.figure.Figure``.
    """
    fig, axs = plt.subplots(
        rows,
        cols,
        gridspec_kw={"wspace": 0, "hspace": 0},
        # Always return a 2-D array of Axes so axs[r][c] also works for a
        # single row or column (squeeze=True would collapse those dimensions).
        squeeze=False,
        # figsize is (width, height): width scales with cols, height with rows.
        figsize=(cols * imsize[0], rows * imsize[1]),
    )
    for i, image in enumerate(images):
        # Column-major placement: consecutive images fill a column first.
        ax = axs[i % rows][i // rows]
        ax.axis("off")
        if directions is not None:
            ax.arrow(
                arrow_origin[0],
                arrow_origin[1],
                directions[i][0] * arrow_scale,
                directions[i][1] * arrow_scale,
                color="red",
                length_includes_head=True,
                head_width=2.0,
                head_length=1.0,
            )
        ax.imshow(image, aspect="auto")
    plt.subplots_adjust(hspace=0, wspace=0)
    if title is not None:
        fig.suptitle(title, fontsize=12)
    if show:
        plt.show()
    return fig


def show_save(save_path, show=True, save=False):
    """Optionally save and/or display the current pyplot figure.

    Args:
        save_path: Destination passed to ``plt.savefig``. Only used when ``save``.
        show: If True, display the figure with ``plt.show()``.
        save: If True, write the current figure to ``save_path``.
    """
    # Save before showing: once plt.show() returns the figure may already be
    # closed/released, and a later savefig would write an empty image.
    if save:
        plt.savefig(save_path)
    if show:
        plt.show()


def color_tensor(tensor: "torch.Tensor | np.ndarray", cmap, norm=False):
    """Map scalar values to RGB colours using a matplotlib colormap.

    Args:
        tensor: torch tensor (any device; gradients are detached) or numpy array
            of scalar values.
        cmap: Colormap name or ``Colormap`` instance accepted by ``plt.get_cmap``.
        norm: If True, min-max rescale the values to [0, 1] before mapping. A
            constant input maps to all zeros instead of NaN.

    Returns:
        CPU torch tensor of shape ``tensor.shape + (3,)`` holding RGB values.
        The colormap's alpha channel is dropped.
    """
    if norm:
        lo, hi = tensor.min(), tensor.max()
        span = hi - lo
        tensor = tensor - lo
        # Guard the constant-input case: dividing by a zero span yields NaN.
        if span > 0:
            tensor = tensor / span
    if isinstance(tensor, torch.Tensor):
        # detach() so tensors that require grad can be converted to numpy.
        tensor = tensor.detach().cpu().numpy()
    # plt.get_cmap replaces plt.cm.get_cmap, which was deprecated in
    # Matplotlib 3.7 and removed in 3.9.
    colormap = plt.get_cmap(cmap)
    return torch.tensor(colormap(tensor))[..., :3]