import torch
import torch.nn as nn

PT_FEATURE_SIZE = 40


class DeepDTAF(nn.Module):

    def __init__(self, smi_charset_len):
        super().__init__()

        smi_embed_size = 128
        seq_embed_size = 128

        seq_oc = 128
        pkt_oc = 128
        smi_oc = 128

        self.smi_embed = nn.Embedding(smi_charset_len, smi_embed_size)

        self.seq_embed = nn.Linear(PT_FEATURE_SIZE, seq_embed_size)  # (N, *, H_in) -> (N, *, H_out)

        # protein sequence branch: stacked dilated residual blocks
        conv_seq = []
        ic = seq_embed_size
        for oc in [32, 64, 64, seq_oc]:
            conv_seq.append(DilatedParllelResidualBlockA(ic, oc))
            ic = oc
        conv_seq.append(nn.AdaptiveMaxPool1d(1))  # (N, seq_oc, 1)
        conv_seq.append(Squeeze())                # (N, seq_oc)
        self.conv_seq = nn.Sequential(*conv_seq)

        # binding pocket branch: plain Conv1d -> BatchNorm -> PReLU stack
        conv_pkt = []
        ic = seq_embed_size
        for oc in [32, 64, pkt_oc]:
            conv_pkt.append(nn.Conv1d(ic, oc, 3))  # (N, C, L)
            conv_pkt.append(nn.BatchNorm1d(oc))
            conv_pkt.append(nn.PReLU())
            ic = oc
        conv_pkt.append(nn.AdaptiveMaxPool1d(1))
        conv_pkt.append(Squeeze())
        self.conv_pkt = nn.Sequential(*conv_pkt)  # (N, pkt_oc)

        # SMILES branch: stacked dilated residual blocks
        conv_smi = []
        ic = smi_embed_size
        for oc in [32, 64, smi_oc]:
            conv_smi.append(DilatedParllelResidualBlockB(ic, oc))
            ic = oc
        conv_smi.append(nn.AdaptiveMaxPool1d(1))
        conv_smi.append(Squeeze())
        self.conv_smi = nn.Sequential(*conv_smi)  # (N, smi_oc)

        self.cat_dropout = nn.Dropout(0.2)

        self.classifier = nn.Sequential(
            nn.Linear(seq_oc + pkt_oc + smi_oc, 128),
            nn.Dropout(0.5),
            nn.PReLU(),
            nn.Linear(128, 64),
            nn.Dropout(0.5),
            nn.PReLU(),
            # nn.Linear(64, 1),
            # nn.PReLU()
        )

    def forward(self, seq, pkt, smi):
        # seq: (N, L, PT_FEATURE_SIZE) per-residue features of the full protein sequence
        seq_embed = self.seq_embed(seq)               # (N, L, seq_embed_size)
        seq_embed = torch.transpose(seq_embed, 1, 2)  # (N, seq_embed_size, L)
        seq_conv = self.conv_seq(seq_embed)           # (N, seq_oc)

        # pkt: (N, L, PT_FEATURE_SIZE) per-residue features of the binding pocket
        pkt_embed = self.seq_embed(pkt)               # (N, L, seq_embed_size)
        pkt_embed = torch.transpose(pkt_embed, 1, 2)
        pkt_conv = self.conv_pkt(pkt_embed)           # (N, pkt_oc)

        # smi: (N, L) integer-encoded SMILES string
        smi_embed = self.smi_embed(smi)               # (N, L, smi_embed_size)
        smi_embed = torch.transpose(smi_embed, 1, 2)
        smi_conv = self.conv_smi(smi_embed)           # (N, smi_oc)

        cat = torch.cat([seq_conv, pkt_conv, smi_conv], dim=1)  # (N, seq_oc + pkt_oc + smi_oc)
        cat = self.cat_dropout(cat)

        output = self.classifier(cat)
        return output


class Squeeze(nn.Module):
    def forward(self, input: torch.Tensor):
        # Drop only the trailing length-1 dimension left by AdaptiveMaxPool1d(1),
        # so a batch of size 1 is not squeezed away as well.
        return input.squeeze(-1)


class CDilated(nn.Module):
    def __init__(self, nIn, nOut, kSize, stride=1, d=1):
        super().__init__()
        padding = int((kSize - 1) / 2) * d  # "same" padding for the given dilation
        self.conv = nn.Conv1d(nIn, nOut, kSize, stride=stride, padding=padding,
                              bias=False, dilation=d)

    def forward(self, input):
        output = self.conv(input)
        return output


class DilatedParllelResidualBlockA(nn.Module):
    def __init__(self, nIn, nOut, add=True):
        super().__init__()
        n = int(nOut / 5)
        n1 = nOut - 4 * n
        self.c1 = nn.Conv1d(nIn, n, 1, padding=0)
        self.br1 = nn.Sequential(nn.BatchNorm1d(n), nn.PReLU())
        self.d1 = CDilated(n, n1, 3, 1, 1)   # dilation rate of 2^0
        self.d2 = CDilated(n, n, 3, 1, 2)    # dilation rate of 2^1
        self.d4 = CDilated(n, n, 3, 1, 4)    # dilation rate of 2^2
        self.d8 = CDilated(n, n, 3, 1, 8)    # dilation rate of 2^3
        self.d16 = CDilated(n, n, 3, 1, 16)  # dilation rate of 2^4
        self.br2 = nn.Sequential(nn.BatchNorm1d(nOut), nn.PReLU())

        if nIn != nOut:
            # print(f'{nIn}-{nOut}: add=False')
            add = False
        self.add = add

    def forward(self, input):
        # reduce
        output1 = self.c1(input)
        output1 = self.br1(output1)

        # split and transform
        d1 = self.d1(output1)
        d2 = self.d2(output1)
        d4 = self.d4(output1)
        d8 = self.d8(output1)
        d16 = self.d16(output1)

        # hierarchical fusion for de-gridding
        add1 = d2
        add2 = add1 + d4
        add3 = add2 + d8
        add4 = add3 + d16

        # merge
        combine = torch.cat([d1, add1, add2, add3, add4], 1)

        # if residual version
        if self.add:
            combine = input + combine
        output = self.br2(combine)
        return output


class DilatedParllelResidualBlockB(nn.Module):
    def __init__(self, nIn, nOut, add=True):
        super().__init__()
        n = int(nOut / 4)
        n1 = nOut - 3 * n
        self.c1 = nn.Conv1d(nIn, n, 1, padding=0)
        self.br1 = nn.Sequential(nn.BatchNorm1d(n), nn.PReLU())
        self.d1 = CDilated(n, n1, 3, 1, 1)  # dilation rate of 2^0
        self.d2 = CDilated(n, n, 3, 1, 2)   # dilation rate of 2^1
        self.d4 = CDilated(n, n, 3, 1, 4)   # dilation rate of 2^2
        self.d8 = CDilated(n, n, 3, 1, 8)   # dilation rate of 2^3
        self.br2 = nn.Sequential(nn.BatchNorm1d(nOut), nn.PReLU())

        if nIn != nOut:
            # print(f'{nIn}-{nOut}: add=False')
            add = False
        self.add = add

    def forward(self, input):
        # reduce
        output1 = self.c1(input)
        output1 = self.br1(output1)

        # split and transform
        d1 = self.d1(output1)
        d2 = self.d2(output1)
        d4 = self.d4(output1)
        d8 = self.d8(output1)

        # hierarchical fusion for de-gridding
        add1 = d2
        add2 = add1 + d4
        add3 = add2 + d8

        # merge
        combine = torch.cat([d1, add1, add2, add3], 1)

        # if residual version
        if self.add:
            combine = input + combine
        output = self.br2(combine)
        return output
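

# A minimal smoke test sketching how the model can be driven with dummy tensors.
# The batch size, sequence/pocket/SMILES lengths, and vocabulary size below are
# illustrative assumptions, not values taken from the original training setup.
if __name__ == '__main__':
    N, seq_len, pkt_len, smi_len = 2, 1000, 63, 150  # assumed shapes for illustration
    smi_charset_len = 64                             # assumed SMILES vocabulary size

    model = DeepDTAF(smi_charset_len)
    model.eval()

    seq = torch.rand(N, seq_len, PT_FEATURE_SIZE)          # protein sequence features
    pkt = torch.rand(N, pkt_len, PT_FEATURE_SIZE)          # binding pocket features
    smi = torch.randint(0, smi_charset_len, (N, smi_len))  # integer-encoded SMILES

    with torch.no_grad():
        out = model(seq, pkt, smi)
    print(out.shape)  # (N, 64) with the classifier head as defined above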