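import torch


class LLFillSpace(torch.nn.Module):
    """Regulariser that discourages the clustering space from collapsing.

    For each batch element (each distinct value of ``batch_idx``) it samples up
    to ``maxhits`` hits, computes the covariance matrix of their clustering-space
    coordinates and penalises a smallest eigenvalue that is small compared with
    the mean eigenvalue, i.e. hits lying on a lower-dimensional surface instead
    of filling the space:

        penalty = log((mean(ev) / (min(ev) + 1e-6) - 1)^2 + 1)

    The per-element penalties are summed.

    Args:
        maxhits: maximum number of hits sampled per batch element, to keep
            memory manageable.
        runevery: if >= 0, the loss is computed on the first call and then on
            every (runevery + 1)-th call, and a zero tensor is returned on the
            calls in between; if < 0, the loss is computed on every call.
    """

    # TODO: this is actually a regulariser; move it to the right file.

    def __init__(self,
                 maxhits: int = 1000,
                 runevery: int = -1):
        super().__init__()
        assert maxhits > 0
        self.maxhits = maxhits
        self.runevery = runevery
        # -1: schedule active, not run yet; -2: schedule disabled (run every call)
        self.counter = -1
        if runevery < 0:
            self.counter = -2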

    def get_config(self):
        # Keras-style config; torch.nn.Module has no get_config(), so the
        # dict is built directly instead of merging with super().get_config().
        return {'maxhits': self.maxhits,
                'runevery': self.runevery}

    def _rs_loop(self, coords):
        """Fill-space penalty for the hits of a single batch element (V x C)."""
        # only select a few hits (with replacement) to keep memory manageable
        nhits = coords.shape[0]
        if nhits > self.maxhits:
            # randint's upper bound is exclusive, so use nhits to allow every hit
            sel = torch.randint(low=0, high=nhits, size=(self.maxhits,),
                                device=coords.device)
        else:
            sel = torch.arange(nhits, device=coords.device)
        coords_selected = torch.index_select(coords, 0, sel)  # V' x C
        # centre the coordinates
        means = torch.mean(coords_selected, dim=0)  # C
        coords_selected = coords_selected - means  # V' x C
        # (biased) covariance matrix
        cov = coords_selected.T @ coords_selected / coords_selected.shape[0]  # C x C
        # cov is symmetric, so eigvalsh returns real eigenvalues with stable
        # gradients; cheap because it runs once per element, no approximation needed
        eigenvals = torch.linalg.eigvalsh(cov)
        # penalise one small eigenvalue (e.g. when the hits build a surface)
        pen = torch.log((torch.mean(eigenvals) / (torch.min(eigenvals) + 1e-6) - 1.) ** 2 + 1.)
        return pen

    def _raw_loss(self, coords, batch_idx):
        # sum the penalty over all batch elements
        loss = coords.new_zeros(())
        for i in batch_idx.unique():
            idx = batch_idx == i
            loss = loss + self._rs_loop(coords[idx, :])
        return loss

    def forward(self, clust_space, batch_idx):
        # counter == -2 (schedule disabled) skips the bookkeeping entirely
        if self.counter >= 0:
            if self.counter < self.runevery:
                # skipped call: zero loss with the input's dtype and device
                self.counter += 1
                return clust_space.new_zeros(())
            self.counter = 0
        lossval = self._raw_loss(clust_space, batch_idx)

        # first call with an active schedule: start counting
        if self.counter == -1:
            self.counter += 1
        return lossval
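

# Illustrative sketch, not part of the original module: the per-element penalty
# computed in LLFillSpace._rs_loop, written as a hypothetical free function
# (`fill_space_penalty` is an assumed name) without the hit subsampling, so its
# behaviour can be checked on small synthetic point clouds. It assumes `coords`
# is a V x C float tensor for a single batch element and uses eigvalsh because
# the covariance matrix is symmetric.
import torch


def fill_space_penalty(coords: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    # centre the hits and build the C x C (biased) covariance matrix
    centred = coords - coords.mean(dim=0)
    cov = centred.T @ centred / coords.shape[0]
    # real eigenvalues of the symmetric covariance matrix, in ascending order
    eigenvals = torch.linalg.eigvalsh(cov)
    # close to 0 when all eigenvalues are equal, large when the smallest vanishes
    return torch.log((eigenvals.mean() / (eigenvals.min() + eps) - 1.0) ** 2 + 1.0)


# Example: hits spread along all three axes vs. hits confined to the z = 0 plane.
#   iso = torch.cat([torch.eye(3), -torch.eye(3)])      # three equal eigenvalues
#   flat = torch.tensor([[1., 0., 0.], [-1., 0., 0.],
#                        [0., 1., 0.], [0., -1., 0.]])  # one zero eigenvalue
# fill_space_penalty(iso) is close to 0, fill_space_penalty(flat) is large.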
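

# Illustrative sketch, not part of the original module: a hypothetical helper
# (`fill_space_schedule` is an assumed name) that replays the counter logic of
# LLFillSpace.forward without any tensors and reports, for each of `n_calls`
# consecutive calls, whether the loss is computed (True) or a zero tensor is
# returned (False). With runevery >= 0 the loss is computed on the first call
# and then on every (runevery + 1)-th call; with runevery < 0, on every call.
from typing import List


def fill_space_schedule(runevery: int, n_calls: int) -> List[bool]:
    counter = -1 if runevery >= 0 else -2
    runs = []
    for _ in range(n_calls):
        if counter >= 0:
            if counter < runevery:
                counter += 1
                runs.append(False)  # skipped call
                continue
            counter = 0
        runs.append(True)  # loss computed
        if counter == -1:
            counter += 1
    return runs


# Example: fill_space_schedule(2, 7) gives
# [True, False, False, True, False, False, True].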