import torch
from torch import nn
import torch.nn.functional as F


class NoDimReduction(nn.Module):
    """Identity reducer used when no dimensionality reduction is applied."""

    def __init__(self, full_channels, reduced_channels):
        super().__init__()
        # Without any projection, the "reduced" width must equal the full width.
        assert full_channels == reduced_channels

    def transform_expand(self, features):
        # Pass features through unchanged; mirrors the interface of the other reducers.
        return features


class MlpDimReduction(nn.Module):
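    """Nonlinear expansion: map reduced-channel features back to the full
    channel width through a single hidden (latent) layer."""
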
    def __init__(self, full_channels, reduced_channels, latent_channels):
        super().__init__()
        self.linear_in = nn.Linear(reduced_channels, latent_channels)
        self.linear_out = nn.Linear(latent_channels, full_channels)
        self.relu = nn.ReLU()

    def transform_expand(self, features):
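        # Expand: reduced -> latent (ReLU) -> full, then L2-normalize each feature vector.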
        latent = self.relu(self.linear_in(features))
        output = self.linear_out(latent)
        return F.normalize(output, dim=-1)


class OrthogonalLinearDimReduction(nn.Module):
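    """Linear expansion from the reduced width back to the full width.

    The weight matrix is initialized to a rectangular identity; the class name
    suggests it is meant to remain orthogonal, though no constraint is enforced here.
    """
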
    def __init__(self, full_channels, reduced_channels):
        super().__init__()
        self.bias = torch.nn.Parameter(torch.zeros(full_channels))
        self.weights = torch.nn.Parameter(torch.eye(full_channels, reduced_channels))

    def transform_expand(self, features):
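        # Expand reduced features to the full width, then L2-normalize each feature vector.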
        output = features @ self.weights.transpose(0, 1) + self.bias
        return F.normalize(output, dim=-1)
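

if __name__ == "__main__":
    # Minimal usage sketch (added for illustration; the channel sizes below are
    # assumptions, not values taken from the surrounding project). It shows that
    # each reducer maps a batch of reduced features back to the full channel width.
    full_channels, reduced_channels, latent_channels = 8, 4, 16
    reduced = torch.randn(2, reduced_channels)

    identity = NoDimReduction(full_channels, full_channels)
    print(identity.transform_expand(torch.randn(2, full_channels)).shape)  # torch.Size([2, 8])

    mlp = MlpDimReduction(full_channels, reduced_channels, latent_channels)
    print(mlp.transform_expand(reduced).shape)  # torch.Size([2, 8])

    linear = OrthogonalLinearDimReduction(full_channels, reduced_channels)
    print(linear.transform_expand(reduced).shape)  # torch.Size([2, 8])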