File size: 3,105 Bytes
4304c2f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
# Modified partialconv source code based on implementation from
# https://github.com/NVIDIA/partialconv/blob/master/models/partialconv2d.py
###############################################################################
# BSD 3-Clause License
#
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
#
# Author & Contact: Guilin Liu ([email protected])
###############################################################################

# Original Author & Contact: Guilin Liu ([email protected])
# Modified by Kevin Shih ([email protected])

import torch
import torch.nn.functional as F
from torch import nn


class PartialConv1d(nn.Conv1d):
    """1D convolution whose output is re-normalized by mask coverage.

    For every output position the layer counts how many input positions under
    the kernel window are valid (mask == 1) and rescales the bias-free
    convolution result by ``window_size / valid_count``. Positions whose
    window contains no valid input produce 0. When no mask is supplied, an
    all-ones mask is used, so only windows overlapping the zero padding are
    rescaled.

    Constructor arguments are forwarded unchanged to ``nn.Conv1d``.

    Attributes:
        multi_channel: always ``False``; not read anywhere in this class.
        return_mask: when set to ``True``, ``forward`` also returns the
            updated mask.
        weight_maskUpdater: all-ones kernel of shape ``(1, 1, kernel_size)``
            used to count valid positions per window.
        slide_winsize: number of elements in one kernel window.
        last_size: input shape for which the cached mask-free normalization
            is valid, or ``(None, None, None)`` when nothing reusable is
            cached.
        update_mask, mask_ratio: tensors computed by the most recent mask
            update.
    """

    def __init__(self, *args, **kwargs):
        self.multi_channel = False
        self.return_mask = False
        super().__init__(*args, **kwargs)

        # Plain tensor (not a registered buffer); forward() moves it to the
        # input's dtype/device on demand.
        self.weight_maskUpdater = torch.ones(1, 1, self.kernel_size[0])
        self.slide_winsize = (
            self.weight_maskUpdater.shape[1] * self.weight_maskUpdater.shape[2]
        )

        self.last_size = (None, None, None)
        self.update_mask = None
        self.mask_ratio = None

    @torch.jit.ignore
    def forward(self, input: torch.Tensor, mask_in: torch.Tensor = None):
        """Apply the partial convolution.

        Args:
            input: 3-D tensor accepted by ``nn.Conv1d`` (batch, channels,
                length).
            mask_in: optional validity mask with a single channel, i.e. shape
                ``(batch, 1, length)`` or ``(1, 1, length)``; it is broadcast
                over the input channels. Non-zero marks valid positions. It is
                cast to the input's dtype and device. If omitted, every input
                position is treated as valid.

        Returns:
            The re-normalized convolution output, or the tuple
            ``(output, update_mask)`` when ``self.return_mask`` is true.
            ``update_mask`` is 1 where the kernel window saw at least one
            valid input position and 0 elsewhere.
        """
        assert len(input.shape) == 3
        in_shape = tuple(input.shape)

        # The cached normalization is only reusable for a mask-free call on an
        # input with the same shape, dtype and device as the one it was built
        # for. A supplied mask always forces a recompute.
        cache_stale = (
            mask_in is not None
            or self.last_size != in_shape
            or self.mask_ratio is None
            or self.mask_ratio.dtype != input.dtype
            or self.mask_ratio.device != input.device
        )

        mask = None
        if cache_stale:
            # Values derived from an explicit mask must never be reused by a
            # later mask-free call, so a masked call invalidates the cache key
            # instead of recording the shape.
            self.last_size = in_shape if mask_in is None else (None, None, None)
            with torch.no_grad():
                # No-op when dtype and device already match.
                self.weight_maskUpdater = self.weight_maskUpdater.to(input)
                if mask_in is None:
                    mask = torch.ones(
                        1, 1, in_shape[2], dtype=input.dtype, device=input.device
                    )
                else:
                    # Accept bool/integer masks by casting to the input dtype.
                    mask = mask_in.to(input)
                # Number of valid input positions under each kernel window.
                self.update_mask = F.conv1d(
                    mask,
                    self.weight_maskUpdater,
                    bias=None,
                    stride=self.stride,
                    padding=self.padding,
                    dilation=self.dilation,
                    groups=1,
                )
                # Epsilon avoids division by zero for empty windows; 1e-6 is
                # used instead of 1e-8 for mixed precision training.
                self.mask_ratio = self.slide_winsize / (self.update_mask + 1e-6)
                self.update_mask = torch.clamp(self.update_mask, 0, 1)
                # Zero the ratio wherever the window had no valid input.
                self.mask_ratio = torch.mul(self.mask_ratio, self.update_mask)

        conv_input = input if mask_in is None else torch.mul(input, mask)
        raw_out = super().forward(conv_input)

        if self.bias is not None:
            # Rescale only the bias-free response, then add the bias back and
            # zero the positions that had no valid input.
            bias_view = self.bias.view(1, self.out_channels, 1)
            output = torch.mul(raw_out - bias_view, self.mask_ratio) + bias_view
            output = torch.mul(output, self.update_mask)
        else:
            output = torch.mul(raw_out, self.mask_ratio)

        if self.return_mask:
            return output, self.update_mask
        return output