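"""Gradio app for CTXGen: generates conotoxin-like peptide sequences with a
pretrained masked-language model. Masked positions are filled one at a time by
temperature sampling; each finished sequence is then scored by re-masking its
subtype and potency slots, and accepted sequences are written to output.csv."""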
import random
import torch
import gradio as gr
import pandas as pd
from utils import create_vocab, setup_seed
from dataset_mlm import get_paded_token_idx_gen, add_tokens_to_vocab

# Use a fresh random seed on every launch; it is recorded in the output CSV
# so a run can be reproduced.
seed = random.randint(0, 99999999)

setup_seed(seed)
device = torch.device("cpu")
vocab_mlm = create_vocab()
vocab_mlm = add_tokens_to_vocab(vocab_mlm)
save_path = 'mlm-model-27.pt'           # pretrained MLM checkpoint
train_seqs = pd.read_csv('C0_seq.csv')  # known training sequences, used to reject non-novel generations
train_seq = train_seqs['Seq'].tolist()
model = torch.load(save_path, weights_only=False, map_location=torch.device('cpu'))
model = model.to(device)

def temperature_sampling(logits, temperature):
    """Sample a single token id from `logits` scaled by `temperature`
    (higher temperature flattens the distribution and increases diversity)."""
    logits = logits / temperature
    probabilities = torch.softmax(logits, dim=-1)
    sampled_token = torch.multinomial(probabilities, 1)
    return sampled_token

def CTXGen(τ, g_num, start, end):
    """Generate `g_num` sequences of random length in [start, end] at sampling
    temperature `τ`; returns the path of a CSV listing each accepted sequence
    with its predicted subtype and potency."""
    # The generation template has six '|'-separated fields: X1 and X2 are the
    # masked subtype and potency slots, X3 is the masked peptide to generate,
    # and the remaining fields are left empty.
    X1 = "X"
    X2 = "X"
    X4 = ""
    X5 = ""
    X6 = ""
    model.eval()
    with torch.no_grad():
        new_seq = None
        # Accumulators for candidate sequences, accepted sequences, and their
        # predicted subtype/potency labels and probabilities.
        generated_seqs = []
        generated_seqs_FINAL = []
        cls_pos_all = []
        cls_probability_all = []
        act_pos_all = []
        act_probability_all = []

        count = 0
        gen_num = int(g_num)
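        # Tokens that must not appear in a finished peptide: ambiguous amino-acid
        # letters plus the vocabulary's special subtype/potency/control tokens.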
        NON_AA = ["B","O","U","Z","X",'<K16>', '<α1β1γδ>', '<Ca22>', '<AChBP>', '<K13>', '<α1BAR>', '<α1β1ε>', '<α1AAR>', '<GluN3A>', '<α4β2>',
                        '<GluN2B>', '<α75HT3>', '<Na14>', '<α7>', '<GluN2C>', '<NET>', '<NavBh>', '<α6β3β4>', '<Na11>', '<Ca13>', 
                        '<Ca12>', '<Na16>', '<α6α3β2>', '<GluN2A>', '<GluN2D>', '<K17>', '<α1β1δε>', '<GABA>', '<α9>', '<K12>', 
                        '<Kshaker>', '<α3β4>', '<Na18>', '<α3β2>', '<α6α3β2β3>', '<α1β1δ>', '<α6α3β4β3>', '<α2β2>','<α6β4>', '<α2β4>',
                        '<Na13>', '<Na12>', '<Na15>', '<α4β4>', '<α7α6β2>', '<α1β1γ>', '<NaTTXR>', '<K11>', '<Ca23>', 
                        '<α9α10>','<α6α3β4>', '<NaTTXS>', '<Na17>','<high>','<low>','[UNK]','[SEP]','[PAD]','[CLS]','[MASK]']

        while count < gen_num:
            # Build a fully masked template of random length within [start, end].
            gen_len = random.randint(int(start), int(end))
            X3 = "X" * gen_len
            seq = [f"{X1}|{X2}|{X3}|{X4}|{X5}|{X6}"]
            vocab_mlm.token_to_idx["X"] = 4  # force the placeholder 'X' onto a fixed special-token id

            new_seq = None  # restart from a fresh template instead of the previous candidate
            padded_seq, _, _, _ = get_paded_token_idx_gen(vocab_mlm, seq, new_seq)
            input_text = ["[MASK]" if i == "X" else i for i in padded_seq]

            gen_length = len(input_text)
            num_masks = input_text.count("[MASK]")

            # Fill one randomly chosen masked position per step, re-encoding the
            # partially completed sequence each time.
            for _ in range(num_masks):
                _, idx_seq, idx_msa, attn_idx = get_paded_token_idx_gen(vocab_mlm, seq, new_seq)
                idx_seq = torch.tensor(idx_seq).unsqueeze(0).to(device)
                idx_msa = torch.tensor(idx_msa).unsqueeze(0).to(device)
                attn_idx = torch.tensor(attn_idx).to(device)

                mask_positions = [j for j in range(gen_length) if input_text[j] == "[MASK]"]
                mask_position = torch.tensor([random.choice(mask_positions)])
                
                logits = model(idx_seq, idx_msa, attn_idx)
                mask_logits = logits[0, mask_position.item(), :]

                predicted_token_id = temperature_sampling(mask_logits, τ)

                # Write the sampled token back into both views of the sequence so
                # the next step conditions on everything filled in so far.
                predicted_token = vocab_mlm.to_tokens(int(predicted_token_id))
                input_text[mask_position.item()] = predicted_token
                padded_seq[mask_position.item()] = predicted_token.strip()
                new_seq = padded_seq

            generated_seq = input_text

            # Re-mask the subtype (position 1) and potency (position 2) slots and
            # run the model once more to score the finished sequence.
            generated_seq[1] = "[MASK]"
            generated_seq[2] = "[MASK]"
            input_ids = vocab_mlm[generated_seq]
            logits = model(torch.tensor([input_ids]).to(device), idx_msa, attn_idx)

            cls_mask_logits = logits[0, 1, :]
            act_mask_logits = logits[0, 2, :]
            
            # Take the top-1 predicted subtype and potency tokens with their probabilities.
            cls_probability, cls_mask_probs = torch.topk(torch.softmax(cls_mask_logits, dim=-1), k=1)
            act_probability, act_mask_probs = torch.topk(torch.softmax(act_mask_logits, dim=-1), k=1)

            cls_pos = vocab_mlm.idx_to_token[cls_mask_probs[0].item()]
            act_pos = vocab_mlm.idx_to_token[act_mask_probs[0].item()]

            cls_probability = cls_probability[0].item()
            act_probability = act_probability[0].item()
            # Keep only the peptide body between the special tokens.
            generated_seq = generated_seq[generated_seq.index('[MASK]') + 2:generated_seq.index('[SEP]')]

            # Accept a candidate only if its cysteines can pair into disulfide
            # bonds (even count) and it has exactly the requested length.
            if generated_seq.count('C') % 2 == 0 and len("".join(generated_seq)) == gen_len:
                generated_seqs.append("".join(generated_seq))
                # Record it only if it is novel: absent from the training set, not a
                # duplicate of an earlier candidate, and free of non-amino-acid tokens.
                if "".join(generated_seq) not in train_seq and "".join(generated_seq) not in generated_seqs[0:-1] and all(x not in NON_AA for x in generated_seq):
                    generated_seqs_FINAL.append("".join(generated_seq))
                    cls_pos_all.append(cls_pos)
                    cls_probability_all.append(cls_probability)
                    act_pos_all.append(act_pos)
                    act_probability_all.append(act_probability)
                    # Rewrite the CSV after every accepted sequence so partial
                    # results survive an interrupted run.
                    out = pd.DataFrame({'Generated_seq': generated_seqs_FINAL, 'Subtype': cls_pos_all, 'Subtype_probability': cls_probability_all, 'Potency': act_pos_all, 'Potency_probability': act_probability_all, 'random_seed': seed})
                    out.to_csv("output.csv", index=False)
                    count += 1
    return 'output.csv'

# Expose the generator as a Gradio app; the returned output.csv is offered for download.
iface = gr.Interface(
    fn=CTXGen,
    inputs=[
        gr.Slider(minimum=1, maximum=2, step=0.01, label="τ (sampling temperature)"),
        gr.Dropdown(choices=[1, 10, 100, 1000], label="Number of generations"),
        gr.Textbox(label="Min length"),
        gr.Textbox(label="Max length")
    ],
    outputs=["file"]
)
iface.launch()
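# Typical use: open the app, pick τ in [1, 2], a generation count, and min/max
# lengths; the app blocks while sampling, then serves output.csv for download.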