import math
import torch
import torch.nn as nn
# NOTE: for the gzlabel contable_gpu env, whose older PyTorch lacks the `batch_first`
# argument of nn.MultiheadAttention, a local copy of the class from a newer PyTorch
# release can be vendored here and used below in place of torch.nn.MultiheadAttention.
class PositionalEncoding(nn.Module):
    """Implement the sinusoidal positional encoding (PE) function."""
    def __init__(self, d_model, dropout, max_len=5000):
        # e.g. d_model=512, dropout=0.1.
        # max_len=5000 precomputes positional encodings for sequences of up to
        # 5000 positions; in practice 100 or 200 is usually more than enough.
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(p=dropout)
        # Compute the positional encodings once in log space.
        pe = torch.zeros(max_len, d_model)
        # (max_len, d_model) matrix holding one d_model-dimensional encoding
        # vector for each of the max_len positions.
        position = torch.arange(0, max_len).unsqueeze(1)
        # (max_len,) -> (max_len, 1)
        div_term = torch.exp(torch.arange(0, d_model, 2) *
                             -(math.log(10000.0) / d_model))
        # d_model // 2 frequency terms (indices 0, 2, ..., d_model - 2), shared by sin and cos.
        pe[:, 0::2] = torch.sin(position * div_term)  # even feature indices
        pe[:, 1::2] = torch.cos(position * div_term)  # odd feature indices
        pe = pe.unsqueeze(0)
        # (max_len, d_model) -> (1, max_len, d_model), leaving room for the batch dimension.
        self.register_buffer('pe', pe)

    def forward(self, x):
        x = x + self.pe[:, :x.size(1)]
        # Takes the embedding output x and adds the (non-trainable) positional
        # encoding pe to it. For example, if x is a (30, 10, 512) tensor
        # (batch size 30, sequence length 10, embedding dim 512), the second
        # term is (1, min(10, max_len), 512) = (1, 10, 512), which broadcasts
        # to (30, 10, 512) so that all 30 sequences in the batch receive the
        # same positional encoding.
        return self.dropout(x)  # apply dropout once more
# Note: the positional encoding is fixed, not learned, so this class has no trainable parameters.
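# Usage sketch (illustrative only; the sizes below are assumptions, not values
# taken from the original training setup):
#   >>> pos_enc = PositionalEncoding(d_model=128, dropout=0.1)
#   >>> x = torch.zeros(2, 10, 128)   # (batch, seq_len, d_model)
#   >>> pos_enc(x).shape              # shape unchanged; PE is broadcast over the batch
#   torch.Size([2, 10, 128])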
class TwoTrackAttention(nn.Module):
    """One two-track update block: self-attention over the track being updated and
    cross-attention from the other track share a single residual connection and
    LayerNorm, followed by a feed-forward sublayer with its own residual and LayerNorm."""
    def __init__(self, d_attn, n_head, d_ff=512, dropout=0.1) -> None:
        super().__init__()
        self.self_attn = torch.nn.MultiheadAttention(
            d_attn, n_head,
            dropout=dropout,
            batch_first=True  # NOTE: the PyTorch version on the gzbl env does not have this argument
        )
self.dropout_self = nn.Dropout(dropout)
        self.cross_attn = torch.nn.MultiheadAttention(
            d_attn, n_head,
            dropout=dropout,
            batch_first=True
        )
self.dropout_cross = nn.Dropout(dropout)
self.norm1 = nn.LayerNorm(d_attn)
self.ff1 = nn.Linear(d_attn, d_ff)
self.dropout_ff = nn.Dropout(dropout)
self.ff2 = nn.Linear(d_ff, d_attn)
self.norm2 = nn.LayerNorm(d_attn)
self.dropout = nn.Dropout(dropout)
self.activation = nn.ReLU()
# self.s_query = nn.Linear(d_attn,d_attn)
# self.s_key = nn.Linear(d_attn,d_attn)
# self.s_value = nn.Linear(d_attn,d_attn)
#
# self.c_query = nn.Linear(d_attn,d_attn)
# self.c_key = nn.Linear(d_attn,d_attn)
# self.c_value = nn.Linear(d_attn,d_attn)
def forward(self, obj_update, obj_message):
        self_update = self.self_attn(
            query=obj_update,
            key=obj_update,
            value=obj_update
        )[0]
        cross_update = self.cross_attn(
            query=obj_update,   # [1, 299, 128]
            key=obj_message,    # [1, 74, 128]
            value=obj_message   # [1, 74, 128]
        )[0]
# [torch.Size([1, 299, 128]), torch.Size([1, 74, 128]), torch.Size([1, 74, 128])]
obj_update = obj_update + self.dropout_self(self_update) + self.dropout_cross(cross_update)
obj_update = self.norm1(obj_update)
ff_update = self.ff2(self.dropout_ff(self.activation(self.ff1(obj_update))))
obj_update = obj_update + self.dropout(ff_update)
obj_update = self.norm2(obj_update)
return obj_update
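# Usage sketch (illustrative; the shapes mirror the comments above and are not prescriptive):
#   >>> block = TwoTrackAttention(d_attn=128, n_head=4)
#   >>> upd = block(torch.randn(1, 299, 128), torch.randn(1, 74, 128))
#   >>> upd.shape
#   torch.Size([1, 299, 128])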
class SymertricTwoTrackAttention(nn.Module):
    def __init__(self, d_attn, n_head, d_ff=512, dropout=0.1, sync=False) -> None:
        super().__init__()
        self.tta1 = TwoTrackAttention(d_attn, n_head, d_ff, dropout)
        self.tta2 = TwoTrackAttention(d_attn, n_head, d_ff, dropout)
        self.sync = sync

    def forward(self, obj_1, obj_2):
        if self.sync:
            # Synchronous update: both tracks attend to the original version of the other track.
            return self.tta1(obj_1, obj_2), self.tta2(obj_2, obj_1)
        else:
            # Sequential update: obj_2 attends to the already-updated obj_1.
            obj_1 = self.tta1(obj_1, obj_2)
            obj_2 = self.tta2(obj_2, obj_1)
            return obj_1, obj_2
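# Usage sketch (illustrative; d_attn/n_head values are assumptions):
#   >>> pair_block = SymertricTwoTrackAttention(d_attn=128, n_head=4, sync=True)
#   >>> a, b = torch.randn(1, 299, 128), torch.randn(1, 74, 128)
#   >>> a2, b2 = pair_block(a, b)   # a2: [1, 299, 128], b2: [1, 74, 128]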
class LinearFF(nn.Module):
def __init__(self, d_in, d_out, dropout=0.1) -> None:
super().__init__()
self.emb = nn.Linear(d_in, d_out)
self.norm = nn.LayerNorm(d_out)
self.dropout = nn.Dropout(dropout)
self.activation = nn.ReLU()
    def forward(self, f_in):
        # Input features arrive channel-first as [B, d_in, L]; move the feature
        # dimension last ([B, L, d_in]) before the linear projection.
        f_in = f_in.permute(0, 2, 1)
        return self.norm(self.dropout(self.activation(self.emb(f_in))))
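# Usage sketch (illustrative sizes, assumed rather than taken from the original config):
#   >>> ff = LinearFF(d_in=1024, d_out=128)
#   >>> ff(torch.randn(1, 1024, 299)).shape   # channel-first input [B, d_in, L]
#   torch.Size([1, 299, 128])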
class ProteinRNAInteraction(nn.Module):
    def __init__(self, d_pro, d_rna, n_layers, d_attn, n_head=4, d_ff=512, dropout=0.1, sync=False) -> None:
        super().__init__()
        print('sync update ProteinRNAInteraction', sync)
        self.pro_emb = LinearFF(d_pro, d_attn)
        self.pro_rna = LinearFF(d_rna, d_attn)
        self.pro_pos = PositionalEncoding(d_attn, dropout)
        self.rna_pos = PositionalEncoding(d_attn, dropout)
        self.layers = nn.ModuleList([
            SymertricTwoTrackAttention(d_attn, n_head, d_ff, dropout, sync=sync) for _ in range(n_layers)
        ])
        self.pred = nn.Linear(d_attn, 1)
        # self.pred = nn.Linear(2*d_attn, 1)
        self.sigmoid = nn.Sigmoid()
def forward(self, f_pro, f_rna):
# print(f_pro.shape)
# print(f_pro.device)
f_pro = self.pro_emb(f_pro)
f_rna = self.pro_rna(f_rna)
f_pro = self.pro_pos(f_pro)
f_rna = self.rna_pos(f_rna)
for layer in self.layers:
f_pro, f_rna = layer(f_pro, f_rna)
        f_pro = f_pro.unsqueeze(2)  # [B, L, 1, D]
        f_rna = f_rna.unsqueeze(1)  # [B, 1, R, D]
        # Broadcasting the element-wise product gives [B, L, R, D]; the final
        # linear + sigmoid maps each protein/RNA position pair to a probability.
        prob = self.sigmoid(self.pred(f_rna.mul(f_pro)))  # [B, L, R, 1]
        return prob
# f_pro = f_pro.unsqueeze(2) # [1, 299, 1, 128]
# f_rna = f_rna.unsqueeze(1) # [1, 1, 74, 128]
# f_pro = f_pro.repeat(1, 1, f_rna.shape[2], 1) # [B, L, R, D]
# f_rna = f_rna.repeat(1, f_pro.shape[1], 1, 1) # [B, L, R, D]
#
# # prob = self.pred(f_rna.mul(f_pro))
# prob = self.pred(torch.cat([f_pro, f_rna], -1))
# # print(prob.max(),prob.min(),prob.mean())
# prob = torch.sigmoid(prob)
# # prob = self.sigmoid(prob)
# # prob = self.sigmoid(self.pred(torch.cat([f_pro, f_rna], -1))) # pred : -0.06, 0.619
# return prob
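# A minimal smoke test, assuming illustrative input sizes (d_pro=1024, d_rna=640 and
# the sequence lengths below are assumptions, not the values used in the original experiments).
if __name__ == "__main__":
    model = ProteinRNAInteraction(d_pro=1024, d_rna=640, n_layers=2, d_attn=128)
    f_pro = torch.randn(1, 1024, 299)   # [B, d_pro, L_protein], channel-first as LinearFF expects
    f_rna = torch.randn(1, 640, 74)     # [B, d_rna, L_rna]
    prob = model(f_pro, f_rna)
    print(prob.shape)                   # expected: torch.Size([1, 299, 74, 1])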