Spaces:
Sleeping
Sleeping
File size: 2,645 Bytes
e75a247 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 |
import torch
class LLFillSpace(torch.nn.Module):
    """Regulariser penalising hit distributions that collapse in one direction.

    For every event (each unique value of ``batch_idx``), up to ``maxhits`` hits
    are sampled and the covariance matrix of their coordinates is built. The
    per-event penalty is

        log((mean(ev) / (min(ev) + 1e-6) - 1) ** 2 + 1)

    where ``ev`` are the covariance eigenvalues. It is ~0 when all eigenvalues
    are equal and grows when one of them is much smaller than the mean (e.g.
    when the hits build a surface). Penalties are summed over events.

    Args:
        maxhits: maximum number of hits sampled per event, to bound memory.
        runevery: if negative, the loss is computed on every call. Otherwise
            it is computed on the first call and then on every
            ``runevery + 1``-th call; the calls in between return a float32
            zero tensor.
    """

    def __init__(self,
                 maxhits: int = 1000,
                 runevery: int = -1):
        # Initialise nn.Module first so its internal state exists before any
        # attribute is assigned.
        super(LLFillSpace, self).__init__()
        assert maxhits > 0
        self.maxhits = maxhits
        self.runevery = runevery
        # counter == -2: no schedule, compute on every call.
        # counter == -1: schedule active, first call not made yet.
        # counter >= 0: number of calls skipped since the last computation.
        self.counter = -1 if runevery >= 0 else -2

    def get_config(self):
        """Return the constructor arguments as a dict.

        torch.nn.Module has no ``get_config``; a parent's config is merged in
        only if some base class actually provides one.
        """
        config = {'maxhits': self.maxhits,
                  'runevery': self.runevery}
        parent_get_config = getattr(super(LLFillSpace, self), 'get_config', None)
        base_config = parent_get_config() if callable(parent_get_config) else {}
        return {**base_config, **config}

    def _rs_loop(self, coords):
        """Return the eigenvalue-spread penalty for a single event.

        Args:
            coords: hit coordinates of one event, indexed ``coords[hit, dim]``
                (V x C).

        Returns:
            0-d float32 tensor holding the penalty.
        """
        nhits = coords.shape[0]
        # Only use a random subset of hits to keep memory manageable: the
        # covariance below materialises a V' x C x C tensor.
        if nhits > self.maxhits:
            # randint's upper bound is exclusive: use nhits so that every hit,
            # including the last one, can be selected (sampling is with
            # replacement).
            sel = torch.randint(low=0, high=nhits, size=(self.maxhits,),
                                device=coords.device)
        else:
            sel = torch.arange(nhits, device=coords.device)
        coords_selected = torch.index_select(coords, 0, sel)  # V' x C
        coords_selected = coords_selected - torch.mean(coords_selected, dim=0)
        # Covariance as the mean of per-hit outer products: V' x C x C -> C x C.
        cov = torch.mean(
            coords_selected.unsqueeze(1) * coords_selected.unsqueeze(2), dim=0
        )
        # cov is symmetric positive semi-definite, so eigvalsh gives real
        # eigenvalues directly (no lossy complex->real cast) and has a
        # well-behaved backward. Clamp small negative round-off values.
        eigenvals = torch.linalg.eigvalsh(cov).clamp(min=0.0).to(torch.float32)
        # Penalise one small eigenvalue (e.g. when the hits build a surface).
        return torch.log(
            (torch.mean(eigenvals) / (torch.min(eigenvals) + 1e-6) - 1.) ** 2 + 1.
        )

    def _raw_loss(self, coords, batch_idx):
        """Sum the per-event penalties over every event present in batch_idx."""
        loss = torch.zeros((), dtype=torch.float32, device=coords.device)
        for event in batch_idx.unique():
            loss = loss + self._rs_loop(coords[batch_idx == event, :])
        return loss

    def forward(self, clust_space, batch_idx):
        """Return the summed fill-space penalty, honouring the runevery schedule.

        Args:
            clust_space: hit coordinates (V x C).
            batch_idx: per-hit event index (V,); hits sharing a value form
                one event.

        Returns:
            0-d float32 tensor: the penalty, or zero on a skipped call.
        """
        if self.counter >= 0:  # counter stays negative (untouched) without a schedule
            if self.counter < self.runevery:
                self.counter += 1
                # Float zero so skipped calls match the dtype of computed ones.
                return torch.zeros((), dtype=torch.float32,
                                   device=clust_space.device)
            self.counter = 0
        lossval = self._raw_loss(clust_space, batch_idx)
        if self.counter == -1:  # first scheduled call: start counting skips
            self.counter = 0
        return lossval
|