import math

import torch
import torch.nn as nn

# NOTE (for the gzlabel contable_gpu env): the PyTorch build on that machine predates the
# ``batch_first`` argument of nn.MultiheadAttention. A vendored, commented-out copy of
# torch.nn.MultiheadAttention was kept in this file for that environment; see
# torch.nn.MultiheadAttention upstream for the reference implementation.
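# Compatibility sketch (not part of the original model code; the helper name is
# illustrative): on PyTorch builds whose nn.MultiheadAttention does not accept
# ``batch_first``, batch-first inputs can be handled by transposing around the call,
# which is what the vendored copy did internally.
def _batch_first_attention(attn: nn.MultiheadAttention, query, key, value):
    # (batch, seq, feature) -> (seq, batch, feature) expected by older MultiheadAttention
    q, k, v = (t.transpose(0, 1) for t in (query, key, value))
    out, weights = attn(q, k, v)
    # back to (batch, seq, feature)
    return out.transpose(0, 1), weights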
class PositionalEncoding(nn.Module):
    "Implement the PE function."

    def __init__(self, d_model, dropout, max_len=5000):
        # Typical values: d_model=512, dropout=0.1. max_len=5000 pre-computes encodings
        # for sequences up to length 5000, which is more than necessary in practice;
        # 100 or 200 is usually enough.
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(p=dropout)

        # Compute the positional encodings once in log space.
        pe = torch.zeros(max_len, d_model)
        # (5000, 512) matrix holding the encoding of every position: 5000 positions,
        # each represented by a 512-dimensional vector.
        position = torch.arange(0, max_len).unsqueeze(1)  # (5000,) -> (5000, 1)
        div_term = torch.exp(torch.arange(0, d_model, 2) *
                             -(math.log(10000.0) / d_model))
        # indices (0, 2, ..., 4998): 2500 frequencies shared by the sin and cos terms
        pe[:, 0::2] = torch.sin(position * div_term)  # even dimensions
        pe[:, 1::2] = torch.cos(position * div_term)  # odd dimensions
        pe = pe.unsqueeze(0)  # (5000, 512) -> (1, 5000, 512), leaving room for the batch dim
        self.register_buffer('pe', pe)

    def forward(self, x):
        # Takes the word-embedding output x and adds the pre-computed positional
        # encoding pe (a registered buffer, so it carries no gradient). For example,
        # if x is a (30, 10, 512) tensor (batch size 30, sequence length 10, embedding
        # dim 512), the second term is (1, min(10, 5000), 512) = (1, 10, 512), which
        # broadcasts to (30, 10, 512) so that all 30 sequences in the batch receive
        # the same positional encoding.
        x = x + self.pe[:, :x.size(1)]
        return self.dropout(x)  # apply one round of dropout

# Note: the positional encoding is fixed (never updated), so this class has no
# trainable parameters.
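# Minimal usage sketch for PositionalEncoding (shapes taken from the comments above,
# purely illustrative): the (1, seq_len, d_model) encoding broadcasts across the batch,
# so every sequence in the batch receives the same positional signal.
def _demo_positional_encoding():
    pe = PositionalEncoding(d_model=512, dropout=0.1)
    x = torch.zeros(30, 10, 512)   # (batch, seq_len, d_model)
    out = pe(x)                    # adds pe[:, :10, :], broadcast over the 30 sequences
    assert out.shape == (30, 10, 512)
    return out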
class TwoTrackAttention(nn.Module):
    def __init__(self, d_attn, n_head, d_ff=512, dropout=0.1) -> None:
        super().__init__()
        self.self_attn = torch.nn.MultiheadAttention(
            d_attn, n_head,
            dropout=dropout,
            batch_first=True  # gzbl: the PyTorch build on that machine lacks this argument
        )
        self.dropout_self = nn.Dropout(dropout)
        self.cross_attn = torch.nn.MultiheadAttention(
            d_attn, n_head,
            dropout=dropout,
            batch_first=True
        )
        self.dropout_cross = nn.Dropout(dropout)
        self.norm1 = nn.LayerNorm(d_attn)

        self.ff1 = nn.Linear(d_attn, d_ff)
        self.dropout_ff = nn.Dropout(dropout)
        self.ff2 = nn.Linear(d_ff, d_attn)
        self.norm2 = nn.LayerNorm(d_attn)
        self.dropout = nn.Dropout(dropout)
        self.activation = nn.ReLU()

        # self.s_query = nn.Linear(d_attn, d_attn)
        # self.s_key = nn.Linear(d_attn, d_attn)
        # self.s_value = nn.Linear(d_attn, d_attn)
        #
        # self.c_query = nn.Linear(d_attn, d_attn)
        # self.c_key = nn.Linear(d_attn, d_attn)
        # self.c_value = nn.Linear(d_attn, d_attn)

    def forward(self, obj_update, obj_message):
        # Self-attention over the track being updated.
        self_update = self.self_attn(
            query=obj_update,
            key=obj_update,
            value=obj_update
        )[0]
        # Cross-attention: the updated track queries the other track's messages.
        cross_update = self.cross_attn(
            query=obj_update,  # e.g. [1, 299, 128]
            key=obj_message,   # e.g. [1, 74, 128]
            value=obj_message  # e.g. [1, 74, 128]
        )[0]
        obj_update = obj_update + self.dropout_self(self_update) + self.dropout_cross(cross_update)
        obj_update = self.norm1(obj_update)

        ff_update = self.ff2(self.dropout_ff(self.activation(self.ff1(obj_update))))
        obj_update = obj_update + self.dropout(ff_update)
        obj_update = self.norm2(obj_update)
        return obj_update


class SymertricTwoTrackAttention(nn.Module):
    def __init__(self, d_attn, n_head, d_ff=512, dropout=0.1, sync=False) -> None:
        super().__init__()
        self.tta1 = TwoTrackAttention(d_attn, n_head, d_ff, dropout)
        self.tta2 = TwoTrackAttention(d_attn, n_head, d_ff, dropout)
        self.sync = sync

    def forward(self, obj_1, obj_2):
        if self.sync:
            # Synchronous update: both tracks attend to the other's pre-update state.
            return self.tta1(obj_1, obj_2), self.tta2(obj_2, obj_1)
        else:
            # Sequential update: the second track sees the already-updated first track.
            obj_1 = self.tta1(obj_1, obj_2)
            obj_2 = self.tta2(obj_2, obj_1)
            return obj_1, obj_2


class LinearFF(nn.Module):
    def __init__(self, d_in, d_out, dropout=0.1) -> None:
        super().__init__()
        self.emb = nn.Linear(d_in, d_out)
        self.norm = nn.LayerNorm(d_out)
        self.dropout = nn.Dropout(dropout)
        self.activation = nn.ReLU()

    def forward(self, f_in):
        # Input arrives as (batch, feature, length); move the feature dim last for the Linear.
        f_in = f_in.permute(0, 2, 1)
        return self.norm(self.dropout(self.activation(self.emb(f_in))))


class ProteinRNAInteraction(nn.Module):
    def __init__(self, d_pro, d_rna, n_layers, d_attn, n_head=4, d_ff=512, dropout=0.1, sync=False) -> None:
        super().__init__()
        print('sync update ProteinRNAInteraction', sync)
        self.pro_emb = LinearFF(d_pro, d_attn)
        self.pro_rna = LinearFF(d_rna, d_attn)
        self.pro_pos = PositionalEncoding(d_attn, dropout)
        self.rna_pos = PositionalEncoding(d_attn, dropout)
        self.layers = nn.ModuleList([
            SymertricTwoTrackAttention(d_attn, n_head, d_ff, dropout, sync=sync)
            for _ in range(n_layers)
        ])
        self.pred = nn.Linear(d_attn, 1)
        # self.pred = nn.Linear(2 * d_attn, 1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, f_pro, f_rna):
        # print(f_pro.shape)
        # print(f_pro.device)
        f_pro = self.pro_emb(f_pro)
        f_rna = self.pro_rna(f_rna)
        f_pro = self.pro_pos(f_pro)
        f_rna = self.rna_pos(f_rna)
        for layer in self.layers:
            f_pro, f_rna = layer(f_pro, f_rna)

        f_pro = f_pro.unsqueeze(2)  # [B, L, 1, D]
        f_rna = f_rna.unsqueeze(1)  # [B, 1, R, D]
        # Element-wise product broadcasts to [B, L, R, D]; pred + sigmoid give a
        # per-(protein residue, RNA nucleotide) interaction probability of shape [B, L, R, 1].
        prob = self.sigmoid(self.pred(f_rna.mul(f_pro)))
        return prob

        # f_pro = f_pro.unsqueeze(2)  # [1, 299, 1, 128]
        # f_rna = f_rna.unsqueeze(1)  # [1, 1, 74, 128]
        # f_pro = f_pro.repeat(1, 1, f_rna.shape[2], 1)  # [B, L, R, D]
        # f_rna = f_rna.repeat(1, f_pro.shape[1], 1, 1)  # [B, L, R, D]
        #
        # # prob = self.pred(f_rna.mul(f_pro))
        # prob = self.pred(torch.cat([f_pro, f_rna], -1))
        # # print(prob.max(), prob.min(), prob.mean())
        # prob = torch.sigmoid(prob)
        # # prob = self.sigmoid(prob)
        # # prob = self.sigmoid(self.pred(torch.cat([f_pro, f_rna], -1)))  # pred: -0.06, 0.619
        # return prob
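# End-to-end usage sketch for ProteinRNAInteraction. The feature dimensions d_pro and
# d_rna below are assumptions chosen for illustration; the sequence lengths (299
# residues, 74 nucleotides) follow the shape comments above. Inputs are
# (batch, feature, length) tensors because LinearFF permutes to (batch, length, feature)
# before its projection; the output is a per-(residue, nucleotide) probability map.
def _demo_protein_rna_interaction():
    model = ProteinRNAInteraction(d_pro=1024, d_rna=640, n_layers=2, d_attn=128)
    f_pro = torch.randn(1, 1024, 299)  # (B, d_pro, protein length)
    f_rna = torch.randn(1, 640, 74)    # (B, d_rna, RNA length)
    prob = model(f_pro, f_rna)         # sigmoid output
    assert prob.shape == (1, 299, 74, 1)
    return prob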