File size: 4,207 Bytes
3f7c489
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
import numpy as np
import torchvision.utils as vutils


def get_np_imgrid(array, nrow=3, padding=0, pad_value=0):
    '''
    Tile a batch of HWC numpy images into a single image grid.

    Numpy counterpart of torchvision.utils.make_grid, except that padding is
    inserted only *between* neighbouring images, not around the border.

    args:
        array: numpy array of shape (n, h, w, c); every image has the same size.
        nrow: number of images per grid row (i.e. number of grid columns).
        padding: number of pixels separating neighbouring images.
        pad_value: fill value for the padding and for unused trailing cells.
    return:
        grid: numpy array with the same dtype as `array` and shape
            (rows*h + padding*(rows-1), nrow*w + padding*(nrow-1), c),
            where rows = ceil(n / nrow). Images are placed row-major.
    '''
    n, h, w, c = array.shape
    row_num = n // nrow + (n % nrow != 0)  # ceil(n / nrow)
    # max(..., 0) keeps the height non-negative for an empty batch (row_num == 0).
    gh = row_num * h + padding * max(row_num - 1, 0)
    gw = nrow * w + padding * (nrow - 1)
    # np.full keeps the input dtype; `np.ones(...) * pad_value` could upcast
    # (e.g. an integer image with a float pad_value turned into float64).
    grid = np.full((gh, gw, c), pad_value, dtype=array.dtype)
    for i in range(n):
        grow, gcol = divmod(i, nrow)
        off_y, off_x = grow * (h + padding), gcol * (w + padding)
        grid[off_y:off_y + h, off_x:off_x + w] = array[i]
    return grid


def split_np_imgrid(imgrid, nimg, nrow, padding=0):
    '''
    Reverse operation of get_np_imgrid: cut an image grid back into images.

    args:
        imgrid: HWC (or HW) image grid as a numpy array.
        nimg: number of images in the grid.
        nrow: number of columns in the image grid.
        padding: number of pixels separating neighbouring images
            (between images only, no outer border).
    return:
        images: list of `nimg` images in row-major order. Each entry is a
            basic slice of `imgrid`, i.e. a view rather than a copy.
    '''
    if nimg <= 0:
        # Nothing to extract; also avoids dividing by row_num == 0 below.
        return []
    row_num = nimg // nrow + (nimg % nrow != 0)  # ceil(nimg / nrow)
    gh, gw = imgrid.shape[:2]
    h = (gh - (row_num - 1) * padding) // row_num
    w = (gw - (nrow - 1) * padding) // nrow
    images = []
    for gid in range(nimg):
        grow, gcol = divmod(gid, nrow)
        off_i, off_j = grow * (h + padding), gcol * (w + padding)
        images.append(imgrid[off_i:off_i + h, off_j:off_j + w])
    return images


class MDTableConvertor:
    '''Render a flat list of items as a markdown table, filled row by row.'''

    def __init__(self, col_num):
        # Column count used by convert() when no title row is given.
        self.col_num = col_num

    def _get_table_row(self, items):
        '''Return one markdown table row: "| a | b |" followed by a newline.'''
        # str() so non-string items work as documented; the former
        # '{:s}'.format(item) raised ValueError for ints, floats, etc.
        return ''.join('| {} '.format(str(item)) for item in items) + '|\n'

    def convert(self, item_list, title=None):
        '''
        args:
            item_list: a list of items (str or anything convertible with
            str()) to be presented in the table, filled row by row.

            title: None, or a list of strings. When None (or empty), an
            empty title row is used and the column number is self.col_num;
            otherwise it is used as the title row and its length overrides
            col_num.

        return:
            table: markdown table string. The last row has fewer cells when
            len(item_list) is not a multiple of the column number.
        '''
        if title:  # None and [] both fall back to an empty title row
            col_num = len(title)
            rows = [self._get_table_row(title)]
        else:
            col_num = self.col_num
            rows = [self._get_table_row([' '] * col_num)]
        rows.append(self._get_table_row(['-'] * col_num))  # header separator
        for i in range(0, len(item_list), col_num):
            rows.append(self._get_table_row(item_list[i:i + col_num]))
        return ''.join(rows)
    

def visual_dict_to_imgrid(visual_dict, col_num=4, padding=0):
    '''
    Tile a dict of images into one grid image plus a markdown layout table.

    args:
        visual_dict: a dictionary mapping names to images of the same size.
            The values are passed to torchvision.utils.make_grid, so
            presumably CHW tensors (TODO confirm against callers).
        col_num: number of columns in the image grid
        padding: number of padding pixels to separate images, forwarded to
            make_grid (padding pixels are filled with 1.0)
    return:
        im_grid: the grid tensor returned by torchvision.utils.make_grid
        layout: markdown table of the image names, in the same row-major
            order as the images in the grid
    '''
    # dict keys() and values() iterate in the same order, so names line up
    # with their images.
    im_names = list(visual_dict.keys())
    im_tensors = list(visual_dict.values())
    im_grid = vutils.make_grid(im_tensors,
                               nrow=col_num,
                               padding=padding,
                               pad_value=1.0)
    layout = MDTableConvertor(col_num).convert(im_names)
    return im_grid, layout


def count_parameters(model, trainable_only=False):
    '''
    Count the scalar parameter elements of a model.

    args:
        model: object exposing .parameters() (e.g. a torch.nn.Module) whose
            items provide .numel() and .requires_grad.
        trainable_only: if True, count only parameters with requires_grad set.
    return:
        total number of parameter elements (0 for a parameter-less model).
    '''
    params = model.parameters()
    if trainable_only:
        # Skip frozen parameters (requires_grad=False).
        params = (p for p in params if p.requires_grad)
    return sum(p.numel() for p in params)
    
    

class WarmupExpLRScheduler:
    '''
    Callable mapping an epoch to a learning rate in three phases:

        epoch < rampup_epochs:
            linear ramp: (lr_max - lr_start) / rampup_epochs * epoch + lr_start
        epoch < rampup_epochs + sustain_epochs:
            plateau at lr_max
        otherwise:
            exponential decay towards lr_min:
            (lr_max - lr_min) * exp_decay ** (epoch - rampup_epochs - sustain_epochs) + lr_min
    '''

    def __init__(self, lr_start=1e-4, lr_max=4e-4, lr_min=5e-6, rampup_epochs=4, sustain_epochs=0, exp_decay=0.75):
        self.lr_start, self.lr_max, self.lr_min = lr_start, lr_max, lr_min
        self.rampup_epochs = rampup_epochs
        self.sustain_epochs = sustain_epochs
        self.exp_decay = exp_decay

    def __call__(self, epoch):
        '''Return the learning rate for `epoch`.'''
        plateau_end = self.rampup_epochs + self.sustain_epochs
        if epoch < self.rampup_epochs:
            slope = (self.lr_max - self.lr_start) / self.rampup_epochs
            return slope * epoch + self.lr_start
        if epoch < plateau_end:
            return self.lr_max
        decay_steps = epoch - self.rampup_epochs - self.sustain_epochs
        span = self.lr_max - self.lr_min
        return span * self.exp_decay ** decay_steps + self.lr_min