Spaces:
Sleeping
Sleeping
import torch | |
class LLFillSpace(torch.nn.Module):
    """Regulariser penalising point clouds that collapse onto a lower-dimensional surface.

    For every batch element (rows sharing a value in ``batch_idx``) the
    covariance of the coordinates is computed, and the penalty

        log(1 + (mean(eigenvalues) / (min(eigenvalues) + 1e-6) - 1)^2)

    is summed over batch elements. The penalty is ~0 for an isotropic cloud.
    It grows when one eigenvalue becomes small relative to the average, e.g.
    when the points lie on a plane.

    Args:
        maxhits: Maximum number of rows per batch element used to build the
            covariance. Larger groups are randomly subsampled (without
            replacement) to keep memory bounded. Must be positive.
        runevery: If negative, the loss is computed on every call. Otherwise
            it is computed on the first call and then on every
            ``runevery + 1``-th call; skipped calls return a float32 zero.

    Raises:
        ValueError: if ``maxhits`` is not positive.
    """

    def __init__(self,
                 maxhits: int = 1000,
                 runevery: int = -1):
        # nn.Module must be initialised before attributes are assigned.
        super().__init__()
        if maxhits <= 0:
            raise ValueError(f"maxhits must be positive, got {maxhits}")
        self.maxhits = maxhits
        self.runevery = runevery
        # Counter states:
        #   -2  -> runevery < 0: compute on every call
        #   -1  -> scheduled mode, not run yet (the first call always computes)
        #   >=0 -> scheduled mode: number of calls skipped since the last run
        self.counter = -2 if runevery < 0 else -1

    def get_config(self):
        """Return the constructor arguments as a dict (Keras-style config).

        ``torch.nn.Module`` provides no ``get_config``, so there is no base
        config to merge with.
        """
        return {'maxhits': self.maxhits,
                'runevery': self.runevery}

    def _rs_loop(self, coords):
        """Compute the penalty for a single batch element.

        Args:
            coords: 2D tensor (nhits, C) with the rows of one batch element
                (assumed floating point -- confirm with callers).

        Returns:
            0-dim float32 tensor holding the penalty.
        """
        nhits = coords.shape[0]
        if nhits > self.maxhits:
            # Random subset without replacement, so every row can be picked
            # and no row is counted twice.
            keep = torch.randperm(nhits, device=coords.device)[:self.maxhits]
            coords = coords[keep]

        centred = coords - coords.mean(dim=0)
        # Biased covariance (mean of per-row outer products), C x C. Computed
        # via a matmul rather than materialising a V x C x C tensor.
        cov = centred.T @ centred / centred.shape[0]

        # cov is symmetric, so eigvalsh returns real eigenvalues directly.
        eigenvals = torch.linalg.eigvalsh(cov).to(torch.float32)
        # Round-off can push the smallest eigenvalue of a (near-)singular PSD
        # matrix slightly below zero; clamp so the denominator stays >= 1e-6.
        smallest = eigenvals.min().clamp(min=0.0)
        # Penalise one eigenvalue being small compared to the average one
        # (e.g. when the points build a surface).
        return torch.log1p((eigenvals.mean() / (smallest + 1e-6) - 1.0) ** 2)

    def _raw_loss(self, coords, batch_idx):
        """Sum the per-batch-element penalties.

        Args:
            coords: 2D tensor (N, C) of coordinates.
            batch_idx: length-N tensor assigning each row to a batch element.

        Returns:
            0-dim float32 tensor on ``coords.device``.
        """
        loss = torch.zeros((), dtype=torch.float32, device=coords.device)
        for b in batch_idx.unique():
            loss = loss + self._rs_loop(coords[batch_idx == b])
        return loss

    def forward(self, clust_space, batch_idx):
        """Return the fill-space penalty, or zero if this call is skipped.

        Args:
            clust_space: 2D tensor (N, C) of coordinates.
            batch_idx: length-N tensor of batch assignments.

        Returns:
            0-dim float32 tensor. It is zero on calls skipped by the
            ``runevery`` schedule.
        """
        if self.counter >= 0:  # scheduled mode; the -2 "always" state bypasses this
            if self.counter < self.runevery:
                self.counter += 1
                return torch.zeros((), dtype=torch.float32,
                                   device=clust_space.device)
            self.counter = 0
        lossval = self._raw_loss(clust_space, batch_idx)
        if self.counter == -1:
            # First call in scheduled mode: start counting skipped calls.
            self.counter = 0
        return lossval