AlienChen committed on
Commit f900c86 · verified · 1 Parent(s): d732b63

Delete siamese

Files changed (1)
  1. siamese/siamese_ppi_decoy.py +0 -187
siamese/siamese_ppi_decoy.py DELETED
@@ -1,187 +0,0 @@
-import os
-import pdb
-import torch
-import torch.nn as nn
-import torch.optim as optim
-from torch.utils.data import Dataset, DataLoader
-from transformers import EsmModel, EsmTokenizer
-from sklearn.model_selection import train_test_split
-import pandas as pd
-from peft import BOFTConfig, get_peft_model
-from datasets import load_from_disk
-import time
-
-os.environ['HF_ENDPOINT'] = 'https://hf-mirror.com'
-
-# Hyperparameters
-HYPERPARAMS = {
-    'learning_rate': 0.001,
-    'batch_size': 32,
-    'num_epochs': 10,
-    # 'boft_block_size': 8,
-    # 'boft_n_butterfly_factor': 1,
-    # 'boft_dropout': 0.1,
-    # 'boft_bias': 'boft_only',
-    # 'boft_modules_to_save': [],  # List any specific modules to save if needed
-    # 'boft_target_modules': ["query", "value", "key", "output.dense", "mlp.fc1", "mlp.fc2"],
-    'margin': 1.0
-}
-
-
-# Siamese NN
-class SiameseNetwork(nn.Module):
-    def __init__(self, encoder):
-        super(SiameseNetwork, self).__init__()
-        self.encoder = encoder
-        self.embedding_dim = encoder.config.hidden_size
-        self.projection = nn.Linear(self.embedding_dim * 2, self.embedding_dim)
-
-    def forward(self, target_tokens, binder_tokens, decoy_tokens):
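-        # Take the hidden state of the first token as a fixed-size per-sequence embedding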
-        target_embedding = self.encoder(**target_tokens).last_hidden_state[:, 0, :]
-        binder_embedding = self.encoder(**binder_tokens).last_hidden_state[:, 0, :]
-        decoy_embedding = self.encoder(**decoy_tokens).last_hidden_state[:, 0, :]
-
-        # Compute joint embeddings
-        anchor_embedding = torch.cat((target_embedding, binder_embedding), dim=-1)
-        positive_embedding = torch.cat((binder_embedding, target_embedding), dim=-1)
-        negative_embedding = torch.cat((decoy_embedding, binder_embedding), dim=-1)
-
-        # Project joint embeddings back to original dimensions
-        anchor_embedding = self.projection(anchor_embedding)
-        positive_embedding = self.projection(positive_embedding)
-        negative_embedding = self.projection(negative_embedding)
-
-        return anchor_embedding, positive_embedding, negative_embedding
-
-
-# Generate scores for candidate binders
-def generate_scores(siamese_net, tokenizer, target_seq, candidate_binders, decoy_seq):
-    siamese_net.eval()
-    scores = []
-
-    with torch.no_grad():
-        target_tokens = tokenizer(target_seq, return_tensors="pt", padding=True, truncation=True).to(device)
-        decoy_tokens = tokenizer(decoy_seq, return_tensors="pt", padding=True, truncation=True).to(device)
-
-        for binder_seq in candidate_binders:
-            binder_tokens = tokenizer(binder_seq, return_tensors="pt", padding=True, truncation=True).to(device)
-            target_embedding, binder_embedding, decoy_embedding = siamese_net(target_tokens, binder_tokens, decoy_tokens)
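-            # Score: cos(target, binder) minus cos(target, decoy); higher means the binder is preferred over the decoy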
-            target_binder_similarity = torch.cosine_similarity(target_embedding, binder_embedding)
-            target_decoy_similarity = torch.cosine_similarity(target_embedding, decoy_embedding)
-            score = target_binder_similarity - target_decoy_similarity
-            scores.append(score.item())
-
-    return scores
-
-
-device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-distributed = torch.cuda.device_count() > 1
-
-
-# Load the pre-trained ESM-2-650M model and tokenizer
-model_name = "facebook/esm2_t33_650M_UR50D"
-tokenizer = EsmTokenizer.from_pretrained(model_name)
-model = EsmModel.from_pretrained(model_name)
-
-siamese_ppi_net = SiameseNetwork(model).to(device)
-if distributed:
-    siamese_ppi_net = torch.nn.DataParallel(siamese_ppi_net)
-
-
-# Define the triplet loss function
-criterion = nn.TripletMarginLoss(margin=HYPERPARAMS['margin']).to(device)
-
-# Define the optimizer
-optimizer = optim.Adam(siamese_ppi_net.parameters(), lr=HYPERPARAMS['learning_rate'])
-
-# Load dataset
-train_dataset = load_from_disk('/home/tc415/muPPIt/dataset/train_mut')
-val_dataset = load_from_disk('/home/tc415/muPPIt/dataset/val_mut')
-test_dataset = load_from_disk('/home/tc415/muPPIt/dataset/test_mut')
-
-# Training loop
-for epoch in range(HYPERPARAMS['num_epochs']):
-    # Training
-    siamese_ppi_net.train()
-    train_loss = 0.0
-    # for target_tokens, binder_tokens, decoy_tokens in train_dataloader:
-    for batch in train_dataset:
-        # pdb.set_trace()
-        start = time.time()
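-        # The dataset stores pre-tokenized triplets: anchor (target), positive (binder), negative (decoy)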
-        target_tokens = {'input_ids': torch.tensor(batch['anchor_input_ids']).to(device),
-                         'attention_mask': torch.tensor(batch['anchor_attention_mask']).to(device)}
-        binder_tokens = {'input_ids': torch.tensor(batch['positive_input_ids']).to(device),
-                         'attention_mask': torch.tensor(batch['positive_attention_mask']).to(device)}
-        decoy_tokens = {'input_ids': torch.tensor(batch['negative_input_ids']).to(device),
-                        'attention_mask': torch.tensor(batch['negative_attention_mask']).to(device)}
-
-
-        # pdb.set_trace()
-        # Forward pass
-        target_embedding, binder_embedding, decoy_embedding = siamese_ppi_net(target_tokens, binder_tokens, decoy_tokens)
-
-        # Compute the triplet loss
-        loss = criterion(target_embedding, binder_embedding, decoy_embedding)
-
-        # Backward pass and optimization
-        optimizer.zero_grad()
-        loss.backward()
-        optimizer.step()
-        train_loss += loss.item()
-
-        print(f"loss = {loss.item()}, time = {time.time()-start}s")
-
-    train_loss /= len(train_dataset)
-
-    # Validation
-    siamese_ppi_net.eval()
-    val_loss = 0.0
-    with torch.no_grad():
-        for batch in val_dataset:
-            target_tokens = {'input_ids': torch.tensor(batch['anchor_input_ids']).to(device),
-                             'attention_mask': torch.tensor(batch['anchor_attention_mask']).to(device)}
-            binder_tokens = {'input_ids': torch.tensor(batch['positive_input_ids']).to(device),
-                             'attention_mask': torch.tensor(batch['positive_attention_mask']).to(device)}
-            decoy_tokens = {'input_ids': torch.tensor(batch['negative_input_ids']).to(device),
-                            'attention_mask': torch.tensor(batch['negative_attention_mask']).to(device)}
-            target_embedding, binder_embedding, decoy_embedding = siamese_ppi_net(target_tokens, binder_tokens, decoy_tokens)
-            loss = criterion(target_embedding, binder_embedding, decoy_embedding)
-            val_loss += loss.item()
-    val_loss /= len(val_dataset)
-
-    print(f"Epoch [{epoch+1}/{HYPERPARAMS['num_epochs']}], Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}")
-
-# Testing
-siamese_ppi_net.eval()
-test_loss = 0.0
-with torch.no_grad():
-    for batch in test_dataset:
-        target_tokens = {'input_ids': torch.tensor(batch['anchor_input_ids']).to(device),
-                         'attention_mask': torch.tensor(batch['anchor_attention_mask']).to(device)}
-        binder_tokens = {'input_ids': torch.tensor(batch['positive_input_ids']).to(device),
-                         'attention_mask': torch.tensor(batch['positive_attention_mask']).to(device)}
-        decoy_tokens = {'input_ids': torch.tensor(batch['negative_input_ids']).to(device),
-                        'attention_mask': torch.tensor(batch['negative_attention_mask']).to(device)}
-        target_embedding, binder_embedding, decoy_embedding = siamese_ppi_net(target_tokens, binder_tokens, decoy_tokens)
-        loss = criterion(target_embedding, binder_embedding, decoy_embedding)
-        test_loss += loss.item()
-test_loss /= len(test_dataset)
-
-print(f"Test Loss: {test_loss:.4f}")
-
-# Save the trained model
-torch.save(siamese_ppi_net.state_dict(), "siamese_ppi_model.pth")
-
-# # Example: Scoring for candidate binders
-# target_seq = "MKTVRQERLKSIVRILERSKEPVSGAQLAEELSVSRQVIVQDIAYLRSLGYNIVATPRGYVLAGG"
-# candidate_binders = [
-#     "KTVNELEKVIKKQGKRAKLIIAIIMIIIIIIVV",
-#     "ATVRELEKQIKKQRKRAKLIIAIVMIFIIVVVV",
-#     "KTVNELEKQIKKQGKRAKLIIAIVMIIIIVVVV"
-# ]
-# decoy_seq = "MHIKPLLSRLAQAAANASATPPPPPPPPPGPAVAEEPLHRPTNPGASSGCHKQPLKQSDCPKRPR"
-
-# scores = generate_scores(siamese_ppi_net, tokenizer, target_seq, candidate_binders, decoy_seq)
-# print("Candidate Binder Scores:")
-# for binder, score in zip(candidate_binders, scores):
-#     print(f"Binder: {binder}, Score: {score:.4f}")