|
import torch |
|
import torch.nn as nn |
|
import torch.nn.functional as F |
|
from torch.nn.functional import binary_cross_entropy_with_logits |
|
import math |
|
from transformers import PreTrainedModel |
|
from .configuration_flowformer import FlowformerConfig |
|
|
|
|
|
class MAB(nn.Module):
    """Multihead Attention Block (MAB) from https://arxiv.org/abs/1810.00825.

    Queries are projected from ``Q`` and keys/values from ``K``. Attention is
    computed per head, added residually to the projected queries, and followed
    by a residual ReLU feed-forward layer. Layer normalisation is applied
    after each residual step only when ``ln`` is true.

    Args:
        dim_Q: Size of the last axis of the query input.
        dim_K: Size of the last axis of the key/value input.
        dim_V: Width of the projected queries, keys, values and of the output.
        num_heads: Number of heads; the feature axis is cut into chunks of
            ``dim_V // num_heads``.
        ln: If true, create the two ``LayerNorm`` layers ``ln0`` and ``ln1``.
    """

    def __init__(self, dim_Q, dim_K, dim_V, num_heads, ln=False):
        super().__init__()

        self.dim_V = dim_V
        self.num_heads = num_heads

        # Creation order is kept stable so seeded initialisation and
        # state-dict key order stay the same.
        self.fc_q = nn.Linear(dim_Q, dim_V)
        self.fc_k = nn.Linear(dim_K, dim_V)
        self.fc_v = nn.Linear(dim_K, dim_V)

        if ln:
            self.ln0 = nn.LayerNorm(dim_V)
            self.ln1 = nn.LayerNorm(dim_V)

        self.fc_o = nn.Linear(dim_V, dim_V)

    def forward(self, Q, K):
        """Attend from ``Q`` to ``K``.

        Args:
            Q: Tensor of shape ``(batch, n_q, dim_Q)``.
            K: Tensor of shape ``(batch, n_k, dim_K)``.

        Returns:
            Tensor of shape ``(batch, n_q, dim_V)``.
        """
        queries = self.fc_q(Q)
        keys = self.fc_k(K)
        values = self.fc_v(K)

        batch_size = queries.size(0)
        head_width = self.dim_V // self.num_heads

        def fold_heads(x):
            # Move the heads from the feature axis onto the batch axis.
            return torch.cat(torch.split(x, head_width, dim=2), dim=0)

        q_heads, k_heads, v_heads = map(fold_heads, (queries, keys, values))

        # The scores are scaled by sqrt(dim_V), i.e. the full width rather
        # than the per-head width.
        scores = torch.bmm(q_heads, k_heads.transpose(1, 2)) / math.sqrt(self.dim_V)
        weights = torch.softmax(scores, dim=2)
        attended = q_heads + torch.bmm(weights, v_heads)

        # Undo the fold: heads go back onto the feature axis.
        out = torch.cat(torch.split(attended, batch_size, dim=0), dim=2)

        first_norm = getattr(self, 'ln0', None)
        if first_norm is not None:
            out = first_norm(out)

        out = out + F.relu(self.fc_o(out))

        second_norm = getattr(self, 'ln1', None)
        if second_norm is not None:
            out = second_norm(out)

        return out
|
|
|
|
|
class ISAB(nn.Module):
    """Induced Set Attention Block (ISAB) from https://arxiv.org/abs/1810.00825.

    A learned set of ``num_inds`` inducing points first attends to the input
    (``mab0``); the input then attends to that result (``mab1``).

    Args:
        dim_in: Size of the last axis of the input.
        dim_out: Width of the inducing points and of the output.
        num_heads: Number of attention heads used by both inner blocks.
        num_inds: Number of learned inducing points.
        ln: Forwarded to both inner ``MAB`` blocks.
    """

    def __init__(self, dim_in, dim_out, num_heads, num_inds, ln=False):
        super().__init__()

        # Initialise the inducing points before the attention blocks are
        # built so the order of random draws is unchanged.
        inducing_points = torch.Tensor(1, num_inds, dim_out)
        nn.init.xavier_uniform_(inducing_points)
        self.I = nn.Parameter(inducing_points)

        self.mab0 = MAB(dim_out, dim_in, dim_out, num_heads, ln=ln)
        self.mab1 = MAB(dim_in, dim_out, dim_out, num_heads, ln=ln)

    def forward(self, X):
        """Apply the block to ``X`` and return the result of ``mab1``."""
        batch_size = X.size(0)
        # One copy of the inducing points per batch element.
        seeds = self.I.repeat(batch_size, 1, 1)
        induced = self.mab0(seeds, X)

        return self.mab1(X, induced)
|
|
|
class Flowformer(PreTrainedModel):
    """Stack of ISAB layers followed by a linear head giving one logit per element.

    The encoder maps ``dim_input`` features to ``dim_hidden`` and back to
    ``dim_input``; the decoder is a single ``Linear(dim_input, 1)``.

    Configuration attributes read: ``dim_input``, ``dim_hidden``,
    ``num_heads``, ``num_inds``, ``hidden_layers``, ``layer_norm`` and
    ``markers`` (a falsy value selects the built-in default marker list).
    """

    # Configuration type associated with this model, as Transformers expects
    # custom models to declare.
    config_class = FlowformerConfig

    def __init__(self, config):
        super().__init__(config)

        dim_input = config.dim_input
        dim_hidden = config.dim_hidden
        num_heads = config.num_heads
        num_inds = config.num_inds
        hidden_layers = config.hidden_layers
        layer_norm = config.layer_norm
        dim_output = 1
        self._pretrained_markers = config.markers or ["TIME", "FSC-A", "FSC-W", "SSC-A", "CD20", "CD10", "CD45", "CD34", "CD19", "CD38", "SY41"]

        # Encoder: input -> hidden, (hidden_layers - 1) x hidden -> hidden,
        # then a single-head block back to the input width.
        enc_layers = [ISAB(dim_input, dim_hidden, num_heads, num_inds, ln=layer_norm)]
        for _ in range(1, hidden_layers):
            enc_layers.append(ISAB(dim_hidden, dim_hidden, num_heads, num_inds, ln=layer_norm))
        enc_layers.append(ISAB(dim_hidden, dim_input, 1, num_inds, ln=layer_norm))
        self.enc = nn.Sequential(*enc_layers)

        dec_layers = [nn.Linear(dim_input, dim_output)]
        self.dec = nn.Sequential(*dec_layers)

    def pretrained_markers(self):
        """Return the marker names, in channel order, that the model expects."""
        return self._pretrained_markers

    def _align_to_pretrained_markers(self, tensor, markers):
        """Rearrange the channels of ``tensor`` into the pretrained marker order.

        Channels whose marker name is not among the pretrained markers are
        dropped; pretrained markers absent from ``markers`` stay zero.

        Raises:
            ValueError: If ``len(markers)`` differs from the channel count.
        """
        batch, length, channels = tensor.shape
        if len(markers) != channels:
            raise ValueError("Number of markers in x and markers must be identical")

        # First occurrence wins, matching ``list.index``.
        target_column = {}
        for column, name in enumerate(self._pretrained_markers):
            target_column.setdefault(name, column)

        source_cols = [i for i, name in enumerate(markers) if name in target_column]
        target_cols = [target_column[markers[i]] for i in source_cols]

        aligned = torch.zeros((batch, length, len(self._pretrained_markers)), device=tensor.device)
        if source_cols:
            # Copy only the recognised channels. Assigning the whole input
            # would fail with a shape mismatch whenever a marker is unknown.
            values = tensor if len(source_cols) == channels else tensor[:, :, source_cols]
            aligned[:, :, target_cols] = values
        return aligned

    def forward(self, tensor, labels=None, markers: list=None):
        """Compute per-element logits and, when ``labels`` is given, the BCE loss.

        Args:
            tensor: Input of shape ``(batch, length, channels)``.
            labels: Optional targets with the same shape as the logits,
                ``(batch, length)``. Cast to the logits' dtype, so integer or
                boolean labels are accepted.
            markers: Optional marker name per input channel. When given, the
                channels are mapped onto the pretrained marker order first.

        Returns:
            ``{'logits': ...}``, plus a ``'loss'`` entry when ``labels`` is given.

        Raises:
            ValueError: If ``markers`` is given and its length differs from
                the number of channels, or if ``tensor`` is not 3-D.
        """
        # Unpacking also rejects non-3-D input up front.
        batch, length, channels = tensor.shape
        if markers is not None:
            tensor = self._align_to_pretrained_markers(tensor, markers)

        enc_out = self.enc(tensor)
        output = self.dec(enc_out)[:, :, 0]

        if labels is None:
            return {'logits': output}

        return {
            # The loss needs floating-point targets of the logits' dtype.
            'loss': binary_cross_entropy_with_logits(output, labels.to(output.dtype)),
            'logits': output,
        }
|
|