import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.functional import binary_cross_entropy_with_logits
import math
from transformers import PreTrainedModel
from .configuration_flowformer import FlowformerConfig


class MAB(nn.Module):
    """
    Multihead attention Block (MAB) from https://arxiv.org/abs/1810.00825.
    """
    def __init__(self, dim_Q, dim_K, dim_V, num_heads, ln=False):
        super(MAB, self).__init__()

        self.dim_V = dim_V
        self.num_heads = num_heads
        # Query, key and value projections
        self.fc_q = nn.Linear(dim_Q, dim_V)
        self.fc_k = nn.Linear(dim_K, dim_V)
        self.fc_v = nn.Linear(dim_K, dim_V)

        # Optional layer normalization after attention and after the feed-forward layer
        if ln:
            self.ln0 = nn.LayerNorm(dim_V)
            self.ln1 = nn.LayerNorm(dim_V)
        self.fc_o = nn.Linear(dim_V, dim_V)

    def forward(self, Q, K):
        Q = self.fc_q(Q)
        K, V = self.fc_k(K), self.fc_v(K)

        # Split the feature dimension into num_heads chunks and stack them along the batch dimension
        dim_split = self.dim_V // self.num_heads
        Q_ = torch.cat(Q.split(dim_split, 2), dim=0)
        K_ = torch.cat(K.split(dim_split, 2), dim=0)
        V_ = torch.cat(V.split(dim_split, 2), dim=0)

        # Scaled dot-product attention with a residual connection,
        # then the heads are merged back into the feature dimension
        A = torch.softmax(Q_.bmm(K_.transpose(1, 2)) / math.sqrt(self.dim_V), 2)
        O = torch.cat((Q_ + A.bmm(V_)).split(Q.size(0), 0), 2)
        O = O if getattr(self, 'ln0', None) is None else self.ln0(O)
        # Position-wise feed-forward layer with a residual connection
        O = O + F.relu(self.fc_o(O))
        O = O if getattr(self, 'ln1', None) is None else self.ln1(O)

        return O


class ISAB(nn.Module):
    """
    The Induced Set Attention Block (ISAB) from https://arxiv.org/abs/1810.00825.
    """
    def __init__(self, dim_in, dim_out, num_heads, num_inds, ln=False):
        super().__init__()

        # Learnable inducing points that mediate attention over the input set
        self.I = nn.Parameter(torch.Tensor(1, num_inds, dim_out))
        nn.init.xavier_uniform_(self.I)
        self.mab0 = MAB(dim_out, dim_in, dim_out, num_heads, ln=ln)
        self.mab1 = MAB(dim_in, dim_out, dim_out, num_heads, ln=ln)

    def forward(self, X):
        # The inducing points attend to the input set, then the input set attends to the result
        H = self.mab0(self.I.repeat(X.size(0), 1, 1), X)

        return self.mab1(X, H)
    
class Flowformer(PreTrainedModel):
    """
    Set Transformer-style encoder (stacked ISABs) followed by a linear decoder
    that produces one logit per event (row) of the input set.
    """
    config_class = FlowformerConfig

    def __init__(self, config):
        super().__init__(config)

        # Load config
        dim_input = config.dim_input
        dim_hidden = config.dim_hidden
        num_heads = config.num_heads
        num_inds = config.num_inds
        hidden_layers = config.hidden_layers
        layer_norm = config.layer_norm
        dim_output = 1  # one logit per event
        self._pretrained_markers = config.markers or ["TIME", "FSC-A", "FSC-W", "SSC-A", "CD20", "CD10", "CD45", "CD34", "CD19", "CD38", "SY41"]

        # Define encoder
        enc_layers = [ISAB(dim_input, dim_hidden, num_heads, num_inds, ln=layer_norm)]
        for _ in range(1, hidden_layers):
            enc_layers.append(ISAB(dim_hidden, dim_hidden, num_heads, num_inds, ln=layer_norm))
        enc_layers.append(ISAB(dim_hidden, dim_input, 1, num_inds, ln=layer_norm)) # num_heads == 1 because dim_input can be a prime number
        self.enc = nn.Sequential(*enc_layers)
 
        # Define decoder
        dec_layers = [nn.Linear(dim_input, dim_output)]
        self.dec = nn.Sequential(*dec_layers)

    def pretrained_markers(self):
        """Return the marker (channel) names expected by the pretrained model, in order."""
        return self._pretrained_markers

    def forward(self, tensor, labels=None, markers: list = None):
        B, L, M = tensor.shape
        if markers is not None:
            assert len(markers) == M, "Length of `markers` must match the number of channels in `tensor`"

            # Re-order the input channels into the marker layout the model was trained on,
            # zero-filling any pretrained marker that is missing from the input
            zeros = torch.zeros((B, L, len(self._pretrained_markers)), dtype=tensor.dtype, device=tensor.device)
            valid_markers = [m for m in markers if m in self._pretrained_markers]
            src_idx = [markers.index(m) for m in valid_markers]
            dst_idx = [self._pretrained_markers.index(m) for m in valid_markers]
            zeros[:, :, dst_idx] = tensor[:, :, src_idx]
            tensor = zeros

        enc_out = self.enc(tensor)
        output = self.dec(enc_out)[:, :, 0]

        if labels is not None:
            return {
                'loss': binary_cross_entropy_with_logits(output, labels),
                'logits': output
            }
        else:
            return {
                'logits': output
            }
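

# Minimal usage sketch, assuming FlowformerConfig accepts as keyword arguments the
# same field names that __init__ reads above; the concrete values are illustrative
# assumptions, not the shipped pretrained configuration. Because of the relative
# import at the top, run this as a module within its package
# (e.g. `python -m <package>.modeling_flowformer`), not as a standalone script.
if __name__ == "__main__":
    config = FlowformerConfig(
        dim_input=11,      # one channel per default pretrained marker
        dim_hidden=32,     # illustrative hidden width
        num_heads=4,
        num_inds=16,       # inducing points per ISAB
        hidden_layers=2,
        layer_norm=True,
        markers=None,      # fall back to the default marker list above
    )
    model = Flowformer(config)

    x = torch.randn(2, 100, 11)   # (batch, events, markers)
    out = model(x)
    print(out['logits'].shape)    # expected: torch.Size([2, 100])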