from typing import Type, List, Union

import torch
from torch import nn as nn
from torch.nn import init as init
from torch.nn.modules.batchnorm import _BatchNorm


@torch.no_grad()
def default_init_weights(
    module_list: Union[List[nn.Module], nn.Module],
    scale: float = 1,
    bias_fill: float = 0,
    **kwargs,
) -> None:
    """Initialize network weights.

    Args:
        module_list (list[nn.Module] | nn.Module): Modules to be initialized.
        scale (float): Scale initialized weights, especially for residual
            blocks. Default: 1.
        bias_fill (float): The value to fill bias. Default: 0.
        kwargs (dict): Other arguments for the initialization function.
    """
    if not isinstance(module_list, list):
        module_list = [module_list]
    for module in module_list:
        for m in module.modules():
            if isinstance(m, (nn.Conv2d, nn.Linear)):
                init.kaiming_normal_(m.weight, **kwargs)
                # Scale the initial weights, e.g. to damp residual branches.
                m.weight.data *= scale
                if m.bias is not None:
                    m.bias.data.fill_(bias_fill)
            elif isinstance(m, _BatchNorm):
                init.constant_(m.weight, 1)
                if m.bias is not None:
                    m.bias.data.fill_(bias_fill)

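# Usage sketch for `default_init_weights` (illustrative only; the layer sizes
# are arbitrary, and the extra keyword arguments are forwarded to
# `init.kaiming_normal_`):
#
#     head = nn.Sequential(
#         nn.Conv2d(3, 64, 3, 1, 1), nn.ReLU(), nn.Conv2d(64, 3, 3, 1, 1))
#     default_init_weights(head, scale=0.1, a=0, mode='fan_in')
#
# `scale=0.1` damps the initial weights, a common choice for layers inside
# residual branches so each block starts close to the identity mapping.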

def make_layer(
    basic_block: Type[nn.Module], num_basic_block: int, **kwarg
) -> nn.Sequential:
    """Make layers by stacking the same blocks.

    Args:
        basic_block (Type[nn.Module]): nn.Module class for the basic block.
        num_basic_block (int): Number of blocks.
        kwarg (dict): Arguments passed to each block's constructor.

    Returns:
        nn.Sequential: Stacked blocks in nn.Sequential.
    """
    layers = []
    for _ in range(num_basic_block):
        layers.append(basic_block(**kwarg))
    return nn.Sequential(*layers)

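# Usage sketch for `make_layer` (illustrative; `ResidualBlock` stands in for
# any block class and is not defined in this file):
#
#     trunk = make_layer(ResidualBlock, num_basic_block=16, num_feat=64)
#
# Every block is constructed with the same `**kwarg`, so `basic_block` must be
# a class (or factory) that can be called repeatedly.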

# TODO: may write a cpp file
def pixel_unshuffle(x: torch.Tensor, scale: int) -> torch.Tensor:
    """Pixel unshuffle: trade spatial resolution for channels.

    Args:
        x (Tensor): Input feature with shape (b, c, hh, hw).
        scale (int): Downsample ratio.

    Returns:
        Tensor: The pixel-unshuffled feature with shape
            (b, c * scale**2, hh // scale, hw // scale).
    """
    b, c, hh, hw = x.size()
    out_channel = c * (scale**2)
    assert hh % scale == 0 and hw % scale == 0, 'Spatial dimensions must be divisible by scale.'
    h = hh // scale
    w = hw // scale
    x_view = x.view(b, c, h, scale, w, scale)
    return x_view.permute(0, 1, 3, 5, 2, 4).reshape(b, out_channel, h, w)
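

if __name__ == '__main__':
    # Minimal self-check sketch (assumes PyTorch >= 1.8, which provides
    # torch.nn.functional.pixel_unshuffle as a reference implementation).
    import torch.nn.functional as F

    x = torch.randn(2, 3, 8, 12)
    y = pixel_unshuffle(x, scale=2)
    assert y.shape == (2, 12, 4, 6)  # (b, c * 2**2, hh // 2, hw // 2)
    assert torch.equal(y, F.pixel_unshuffle(x, downscale_factor=2))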