MoraxCheng committed on
Commit
7150117
·
verified ·
1 Parent(s): 09e6e98

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +381 -0
app.py ADDED
@@ -0,0 +1,381 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ Tranception Design App - Hugging Face Spaces Version
4
+ """
5
+ import os
6
+ import sys
7
+ import torch
8
+ import transformers
9
+ from transformers import PreTrainedTokenizerFast
10
+ import numpy as np
11
+ import pandas as pd
12
+ import matplotlib.pyplot as plt
13
+ import seaborn as sns
14
+ import gradio as gr
15
+ from huggingface_hub import hf_hub_download
16
+ import zipfile
17
+ import shutil
18
+
19
# Make modules that live next to this file importable.
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))

# The `tranception` package is expected in the working directory; fetch it
# from the upstream repository when it is missing.
if not os.path.exists("tranception"):
    import subprocess

    print("Downloading Tranception repository...")
    # A checkout left behind by an interrupted earlier run would make
    # `git clone` fail, so start from a clean slate.
    shutil.rmtree("temp_tranception", ignore_errors=True)
    try:
        # Argument list + check=True: no shell is involved, and a failed clone
        # raises CalledProcessError right here instead of surfacing later as
        # an unrelated "no such file" error from shutil.move.
        subprocess.run(
            [
                "git",
                "clone",
                "--depth",
                "1",
                "https://github.com/OATML-Markslab/Tranception.git",
                "temp_tranception",
            ],
            check=True,
        )
        # Only the package directory is needed; move it into the working directory.
        shutil.move("temp_tranception/tranception", "tranception")
    finally:
        # Remove the temporary checkout whether or not the steps above succeeded.
        shutil.rmtree("temp_tranception", ignore_errors=True)
31
+
32
+ import tranception
33
+ from tranception import config, model_pytorch
34
+
35
+ # Download model checkpoints if not present
36
def download_model_from_hf(model_name):
    """Resolve the location from which the weights for ``model_name`` are loaded.

    Despite its name, this function does not download anything itself: it
    only returns an identifier that the caller hands to ``from_pretrained``.

    Args:
        model_name: Checkpoint name, e.g. ``"Tranception_Small"``.

    Returns:
        ``"./<model_name>"`` when that path exists in the working directory.
        Otherwise the Hub id ``"PascalNotin/<model_name>"`` for the Small and
        Medium checkpoints, and ``"PascalNotin/Tranception_Medium"`` as the
        fallback for any other name (including the Large model).
    """
    local_path = f"./{model_name}"
    if os.path.exists(local_path):
        return local_path

    print(f"Downloading {model_name} model...")
    if model_name in ("Tranception_Small", "Tranception_Medium"):
        return f"PascalNotin/{model_name}"

    # Any other checkpoint is not mapped to a Hub id here, so fall back to Medium.
    print("Note: Large model needs to be downloaded from the original source.")
    print("Using Medium model as fallback...")
    return "PascalNotin/Tranception_Medium"
54
+
55
# One-letter amino-acid alphabet; used below as the default set of
# substitution targets.
AA_vocab = "ACDEFGHIKLMNPQRSTVWY"

# Tokenizer loaded from the vocabulary file inside the local `tranception`
# package. The path is relative to the working directory.
tokenizer = PreTrainedTokenizerFast(
    tokenizer_file="./tranception/utils/tokenizers/Basic_tokenizer",
    cls_token="[CLS]",
    sep_token="[SEP]",
    pad_token="[PAD]",
    mask_token="[MASK]",
    unk_token="[UNK]",
)
63
+
64
def create_all_single_mutants(sequence,AA_vocab=AA_vocab,mutation_range_start=None,mutation_range_end=None):
    """Enumerate every single-residue substitution inside a window of ``sequence``.

    Args:
        sequence: Starting protein sequence.
        AA_vocab: Iterable of candidate replacement residues.
        mutation_range_start: First position to mutate (1-indexed); defaults to 1.
        mutation_range_end: Last position to mutate (inclusive); defaults to
            ``len(sequence)``.

    Returns:
        A DataFrame with two columns: ``mutant`` (e.g. ``"M1A"``: original
        residue, 1-indexed position, new residue) and ``mutated_sequence``
        (the full sequence carrying that substitution).
    """
    first = 1 if mutation_range_start is None else mutation_range_start
    last = len(sequence) if mutation_range_end is None else mutation_range_end

    residues = list(sequence)
    sequence_by_mutant = {}
    for site, wild_type in enumerate(sequence[first - 1:last], start=first):
        for substitute in AA_vocab:
            if substitute == wild_type:
                # Replacing a residue by itself is not a mutation.
                continue
            variant = residues.copy()
            variant[site - 1] = substitute
            sequence_by_mutant[wild_type + str(site) + substitute] = "".join(variant)

    return pd.DataFrame(
        {
            'mutant': list(sequence_by_mutant.keys()),
            'mutated_sequence': list(sequence_by_mutant.values()),
        }
    )
79
+
80
def create_scoring_matrix_visual(scores,sequence,image_index=0,mutation_range_start=None,mutation_range_end=None,AA_vocab=AA_vocab,annotate=True,fontsize=20):
    """Write the heatmap image and the CSV for one window of substitution scores.

    Args:
        scores: DataFrame with at least the columns ``mutant``, ``position``,
            ``target_AA`` and ``avg_score``.
        sequence: Starting protein sequence (positions are 1-indexed into it).
        image_index: Suffix used in the two output file names.
        mutation_range_start: First position shown; defaults to 1.
        mutation_range_end: Last position shown (inclusive); defaults to
            ``len(sequence)``.
        AA_vocab: Residues shown as heatmap columns, in this order.
        annotate: When True, every cell is labelled with its mutant name and score.
        fontsize: Base font size for annotations, tick labels and titles.

    Returns:
        ``(image_path, csv_path)`` of the files written to the working directory.
    """
    # The signature advertises None defaults; resolve them instead of letting
    # range() fail on None.
    if mutation_range_start is None:
        mutation_range_start = 1
    if mutation_range_end is None:
        mutation_range_end = len(sequence)
    positions = range(mutation_range_start, mutation_range_end + 1)
    amino_acids = list(AA_vocab)
    mutation_range_len = len(positions)

    window_scores = scores[scores.position.isin(positions)]
    # Re-index onto the full (position x amino acid) grid. Without this, a
    # window whose positions all share one wild-type residue (e.g. a single
    # position) yields a pivot with a missing column, and the annotation array
    # and x tick labels below would no longer match its shape.
    piv = (
        window_scores.pivot(index='position', columns='target_AA', values='avg_score')
        .round(4)
        .reindex(index=list(positions), columns=amino_acids)
    )

    # One dictionary lookup per cell instead of a boolean-mask scan of the frame.
    score_by_mutant = dict(zip(window_scores.mutant, window_scores.avg_score))

    # One record per grid cell; cells without a score (the wild-type residue
    # itself) are reported as 0.0. The records feed both the CSV and the labels.
    csv_data = []
    for position in positions:
        original_AA = sequence[position - 1]
        for target_AA in amino_acids:
            mutant = original_AA + str(position) + target_AA
            csv_data.append({
                'position': position,
                'original_AA': original_AA,
                'target_AA': target_AA,
                'mutation': mutant,
                'fitness_score': float(score_by_mutant.get(mutant, 0.0)),
            })

    csv_path = 'fitness_scoring_substitution_matrix_{}.csv'.format(image_index)
    pd.DataFrame(csv_data).to_csv(csv_path, index=False)

    # One inch of figure height per sequence position.
    fig, ax = plt.subplots(figsize=(50, mutation_range_len))
    ax.tick_params(bottom=True, top=True, left=True, right=True)
    ax.tick_params(labelbottom=True, labeltop=True, labelleft=True, labelright=True)

    # Colour limits come from the full score table (2nd/98th percentile), not
    # just from this window.
    heatmap_options = dict(
        cmap='RdYlGn',
        linewidths=0.30,
        ax=ax,
        vmin=np.percentile(scores.avg_score, 2),
        vmax=np.percentile(scores.avg_score, 98),
        cbar_kws={'label': 'Log likelihood ratio (mutant / starting sequence)'},
        annot_kws={"size": fontsize},
    )
    if annotate:
        labels = np.asarray(
            ["{} \n {:.4f}".format(row['mutation'], row['fitness_score']) for row in csv_data]
        ).reshape(mutation_range_len, len(amino_acids))
        heat = sns.heatmap(piv, annot=labels, fmt="", **heatmap_options)
    else:
        heat = sns.heatmap(piv, **heatmap_options)

    # The last axes of the figure is the colour bar; enlarge its label.
    heat.figure.axes[-1].yaxis.label.set_size(fontsize=int(fontsize*1.5))
    heat.set_title("Higher predicted scores (green) imply higher protein fitness",fontsize=fontsize*2, pad=40)
    heat.set_ylabel("Sequence position", fontsize = fontsize*2)
    heat.set_xlabel("Amino Acid mutation", fontsize = fontsize*2)

    # Row labels show the position together with its wild-type residue.
    yticklabels = [str(pos)+' ('+sequence[pos-1]+')' for pos in positions]
    heat.set_yticklabels(yticklabels, fontsize=fontsize, rotation=0)
    heat.set_xticklabels(amino_acids, fontsize=fontsize)

    fig.tight_layout()
    image_path = 'fitness_scoring_substitution_matrix_{}.png'.format(image_index)
    fig.savefig(image_path, dpi=100)
    # Close this specific figure so repeated calls do not accumulate open figures.
    plt.close(fig)
    return image_path, csv_path
154
+
155
def suggest_mutations(scores):
    """Build the recommendation text shown next to the heatmaps.

    Args:
        scores: DataFrame with the columns ``mutant``, ``position`` and
            ``avg_score``.

    Returns:
        A message listing the five highest-scoring single mutants and, when at
        least one score is positive, the (up to) five positions with the
        highest mean score among their positive-scoring mutants.
    """
    # Five best mutants, each rendered as "<mutant> (<score rounded to 4 digits>)".
    ranked = scores.sort_values(by=['avg_score'], ascending=False).head(5)
    best_mutants = ", ".join(
        name + " (" + str(round(fitness, 4)) + ")"
        for name, fitness in zip(ranked.mutant, ranked.avg_score)
    )
    parts = [
        "The following mutations may be sensible options to improve fitness: \n\n",
        "The single mutants with highest predicted fitness are (positive scores indicate fitness increase Vs starting sequence, negative scores indicate fitness decrease):\n {} \n\n".format(best_mutants),
    ]

    # Best positions: only mutants with a positive score take part in the average.
    gains = scores[scores.avg_score > 0]
    if len(gains) == 0:
        parts.append("No positions with positive fitness effects found.")
    else:
        mean_gain = gains.groupby(['position'])['avg_score'].mean().reset_index()
        best_positions = (
            mean_gain.sort_values(by=['avg_score'], ascending=False)
            .head(5)['position']
            .astype(str)
        )
        parts.append(
            "The positions with the highest average fitness increase are (only positions with at least one fitness increase are considered):\n {}".format(", ".join(best_positions))
        )
    return "".join(parts)
172
+
173
def check_valid_mutant(sequence,mutant,AA_vocab=AA_vocab):
    """Tell whether ``mutant`` is a well-formed substitution of ``sequence``.

    A mutant is written ``<from_AA><position><to_AA>`` (e.g. ``"M1A"``) with a
    1-indexed position.

    Args:
        sequence: Starting protein sequence.
        mutant: Mutation triplet to check.
        AA_vocab: Allowed replacement residues.

    Returns:
        True when the triplet parses, the position lies inside the sequence,
        the residue at that position equals ``from_AA`` and ``to_AA`` is in
        ``AA_vocab``; False otherwise, including for malformed input.
    """
    try:
        from_AA, position, to_AA = mutant[0], int(mutant[1:-1]), mutant[-1]
    except (IndexError, TypeError, ValueError):
        # Empty, too short, non-numeric position or not a string at all:
        # report "invalid" rather than failing on unparsed fields further down.
        return False
    if not 0 < position <= len(sequence):
        return False
    return sequence[position - 1] == from_AA and to_AA in AA_vocab
185
+
186
def get_mutated_protein(sequence,mutant):
    """Apply one substitution to ``sequence``.

    Args:
        sequence: Starting protein sequence.
        mutant: Mutation triplet such as ``"M1A"`` (1-indexed position).

    Returns:
        The mutated sequence, or the string ``"The mutant is not valid"`` when
        the triplet is malformed or does not match the sequence.
    """
    invalid_message = "The mutant is not valid"

    # Normalise the same way score_and_create_matrix_all_singles cleans its
    # input; otherwise surrounding whitespace or lower-case letters would make
    # a mutation reported by the scoring step fail validation here.
    sequence = (sequence or "").strip().upper()
    mutant = (mutant or "").strip().upper()

    # Parse the position up front so malformed triplets are rejected here.
    try:
        position = int(mutant[1:-1])
    except ValueError:
        return invalid_message
    if not check_valid_mutant(sequence, mutant):
        return invalid_message

    residues = list(sequence)
    residues[position - 1] = mutant[-1]
    return ''.join(residues)
192
+
193
def score_and_create_matrix_all_singles(sequence,mutation_range_start=None,mutation_range_end=None,model_type="Small",scoring_mirror=False,batch_size_inference=20,max_number_positions_per_heatmap=50,num_workers=0,AA_vocab=AA_vocab):
    """Score all single substitutions in a window and produce heatmaps, CSVs and advice.

    Args:
        sequence: Starting protein sequence; it is stripped and upper-cased.
        mutation_range_start: First position to mutate (1-indexed); defaults to 1.
        mutation_range_end: Last position to mutate (inclusive); defaults to
            the length of the cleaned sequence.
        model_type: ``"Small"``, ``"Medium"`` or ``"Large"`` (Large currently
            loads the Medium checkpoint).
        scoring_mirror: Forwarded to ``model.score_mutants``.
        batch_size_inference: Inference batch size; capped at 10 on CPU.
        max_number_positions_per_heatmap: Number of positions per heatmap image.
        num_workers: Forwarded to ``model.score_mutants``.
        AA_vocab: Candidate replacement residues.

    Returns:
        ``(score_heatmaps, recommendations, csv_files)``: the list of heatmap
        image paths, the recommendation text, and the list of CSV paths (one
        per heatmap plus one file with all mutations).

    Raises:
        ValueError: On an empty sequence, an invalid mutation window or an
            unknown ``model_type``.
        RuntimeError: When no checkpoint location could be resolved.
    """
    # Clean the sequence BEFORE deriving defaults from it, so the default end
    # position refers to the cleaned sequence rather than to raw input that
    # may carry surrounding whitespace.
    sequence = (sequence or "").strip().upper()
    if not sequence:
        raise ValueError("no sequence entered")

    mutation_range_start = 1 if mutation_range_start is None else int(mutation_range_start)
    mutation_range_end = len(sequence) if mutation_range_end is None else int(mutation_range_end)

    # Explicit checks instead of `assert`, which is stripped under `python -O`.
    if mutation_range_start < 1:
        raise ValueError(f"Start position ({mutation_range_start}) must be at least 1")
    if mutation_range_start > mutation_range_end:
        raise ValueError("mutation range is invalid")
    if mutation_range_end > len(sequence):
        raise ValueError(f"End position ({mutation_range_end}) exceeds sequence length ({len(sequence)})")

    # "Large" deliberately maps to the Medium checkpoint (see the note printed below).
    checkpoint_by_model_type = {
        "Small": "Tranception_Small",
        "Medium": "Tranception_Medium",
        "Large": "Tranception_Medium",
    }
    if model_type not in checkpoint_by_model_type:
        raise ValueError(f"Unknown model type: {model_type!r}")
    if model_type == "Large":
        print("Note: Large model requires significant memory. Using Medium model for HF Spaces deployment.")
    model_path = download_model_from_hf(checkpoint_by_model_type[model_type])
    if model_path is None:
        raise RuntimeError(f"Could not resolve a checkpoint for model type {model_type!r}")
    model = tranception.model_pytorch.TranceptionLMHeadModel.from_pretrained(pretrained_model_name_or_path=model_path)

    if torch.cuda.is_available():
        model.to(torch.device("cuda"))
        print("Inference will take place on NVIDIA GPU")
    else:
        model.to(torch.device("cpu"))
        print("Inference will take place on CPU")
        # Reduce batch size for CPU inference
        batch_size_inference = min(batch_size_inference, 10)

    model.eval()
    model.config.tokenizer = tokenizer

    all_single_mutants = create_all_single_mutants(sequence,AA_vocab,mutation_range_start,mutation_range_end)

    with torch.no_grad():
        scores = model.score_mutants(DMS_data=all_single_mutants,
                                     target_seq=sequence,
                                     scoring_mirror=scoring_mirror,
                                     batch_size_inference=batch_size_inference,
                                     num_workers=num_workers,
                                     indel_mode=False
                                     )

    # Re-attach the mutant names, then derive position and target residue from them.
    scores = pd.merge(scores,all_single_mutants,on="mutated_sequence",how="left")
    scores["position"]=scores["mutant"].map(lambda x: int(x[1:-1]))
    scores["target_AA"] = scores["mutant"].map(lambda x: x[-1])

    # NOTE(review): all output files use fixed names in the working directory,
    # so concurrent requests would overwrite each other's files.
    score_heatmaps = []
    csv_files = []
    # One heatmap per chunk of `max_number_positions_per_heatmap` positions.
    window_starts = range(mutation_range_start, mutation_range_end + 1, max_number_positions_per_heatmap)
    for image_index, window_start in enumerate(window_starts):
        window_end = min(mutation_range_end, window_start + max_number_positions_per_heatmap - 1)
        image_path, csv_path = create_scoring_matrix_visual(scores,sequence,image_index,window_start,window_end,AA_vocab)
        score_heatmaps.append(image_path)
        csv_files.append(csv_path)

    # Also save a comprehensive CSV with all mutations
    comprehensive_csv_path = 'all_mutations_fitness_scores.csv'
    scores_export = scores[['mutant', 'position', 'target_AA', 'avg_score', 'mutated_sequence']].copy()
    scores_export['original_AA'] = scores_export['mutant'].str[0]
    scores_export = scores_export.rename(columns={'avg_score': 'fitness_score'})
    scores_export = scores_export[['position', 'original_AA', 'target_AA', 'mutant', 'fitness_score', 'mutated_sequence']]
    scores_export.to_csv(comprehensive_csv_path, index=False)
    csv_files.append(comprehensive_csv_path)

    return score_heatmaps, suggest_mutations(scores), csv_files
273
+
274
def extract_sequence(example):
    """Return the sequence from a ``(label, taxon, sequence)`` example row."""
    _label, _taxon, residues = example
    return residues
277
+
278
def clear_inputs(protein_sequence_input,mutation_range_start,mutation_range_end):
    """Return the reset values for the three inputs: empty sequence, no start, no end.

    The incoming values are ignored; the parameters exist only so the
    function can be wired to the three input components.
    """
    return "", None, None
283
+
284
+ # Create Gradio app
285
+ tranception_design = gr.Blocks()
286
+
287
+ with tranception_design:
288
+ gr.Markdown("# In silico directed evolution for protein redesign with Tranception")
289
+ gr.Markdown("## 🧬 Hugging Face Spaces Demo")
290
+ gr.Markdown("Perform in silico directed evolution with Tranception to iteratively improve the fitness of a protein of interest, one mutation at a time. At each step, the Tranception model computes the log likelihood ratios of all possible single amino acid substitution Vs the starting sequence, and outputs a fitness heatmap and recommandations to guide the selection of the mutation to apply.")
291
+ gr.Markdown("**Note**: This demo runs on CPU in Hugging Face Spaces. For faster inference, consider using GPU locally or selecting the Small model.")
292
+
293
+ with gr.Tabs():
294
+ with gr.TabItem("Input"):
295
+ with gr.Row():
296
+ protein_sequence_input = gr.Textbox(lines=1,
297
+ label="Protein sequence",
298
+ placeholder = "Input the sequence of amino acids representing the starting protein of interest or select one from the list of examples below. You may enter the full sequence or just a subdomain (providing full context typically leads to better results, but is slower at inference)"
299
+ )
300
+
301
+ with gr.Row():
302
+ mutation_range_start = gr.Number(label="Start of mutation window (first position indexed at 1)", value=1, precision=0)
303
+ mutation_range_end = gr.Number(label="End of mutation window (leave empty for full lenth)", value=10, precision=0)
304
+
305
+ with gr.TabItem("Parameters"):
306
+ with gr.Row():
307
+ model_size_selection = gr.Radio(label="Tranception model size (larger models are more accurate but are slower at inference)",
308
+ choices=["Small","Medium","Large"],
309
+ value="Small")
310
+ with gr.Row():
311
+ scoring_mirror = gr.Checkbox(label="Score protein from both directions (leads to more robust fitness predictions, but doubles inference time)")
312
+ with gr.Row():
313
+ batch_size_inference = gr.Number(label="Model batch size at inference time (reduce for CPU)",value = 10, precision=0)
314
+ with gr.Row():
315
+ gr.Markdown("Note: the current version does not leverage retrieval of homologs at inference time to increase fitness prediction performance.")
316
+
317
+ with gr.Row():
318
+ clear_button = gr.Button(value="Clear",variant="secondary")
319
+ run_button = gr.Button(value="Predict fitness",variant="primary")
320
+
321
+ protein_ID = gr.Textbox(label="Uniprot ID", visible=False)
322
+ taxon = gr.Textbox(label="Taxon", visible=False)
323
+
324
+ examples = gr.Examples(
325
+ inputs=[protein_ID, taxon, protein_sequence_input],
326
+ outputs=[protein_sequence_input],
327
+ fn=extract_sequence,
328
+ examples=[
329
+ ['ADRB2_HUMAN' ,'Human', 'MGQPGNGSAFLLAPNGSHAPDHDVTQERDEVWVVGMGIVMSLIVLAIVFGNVLVITAIAKFERLQTVTNYFITSLACADLVMGLAVVPFGAAHILMKMWTFGNFWCEFWTSIDVLCVTASIETLCVIAVDRYFAITSPFKYQSLLTKNKARVIILMVWIVSGLTSFLPIQMHWYRATHQEAINCYANETCCDFFTNQAYAIASSIVSFYVPLVIMVFVYSRVFQEAKRQLQKIDKSEGRFHVQNLSQVEQDGRTGHGLRRSSKFCLKEHKALKTLGIIMGTFTLCWLPFFIVNIVHVIQDNLIRKEVYILLNWIGYVNSGFNPLIYCRSPDFRIAFQELLCLRRSSLKAYGNGYSSNGNTGEQSGYHVEQEKENKLLCEDLPGTEDFVGHQGTVPSDNIDSQGRNCSTNDSLL'],
330
+ ['IF1_ECOLI' ,'Prokaryote', 'MAKEDNIEMQGTVLETLPNTMFRVELENGHVVTAHISGKMRKNYIRILTGDKVTVELTPYDLSKGRIVFRSR'],
331
+ ['P53_HUMAN' ,'Human', 'MEEPQSDPSVEPPLSQETFSDLWKLLPENNVLSPLPSQAMDDLMLSPDDIEQWFTEDPGPDEAPRMPEAAPRVAPAPAAPTPAAPAPAPSWPLSSSVPSQKTYQGSYGFRLGFLHSGTAKSVTCTYSPALNKMFCQLAKTCPVQLWVDSTPPPGTRVRAMAIYKQSQHMTEVVRRCPHHERCSDSDGLAPPQHLIRVEGNLRVEYLDDRNTFRHSVVVPYEPPEVGSDCTTIHYNYMCNSSCMGGMNRRPILTIITLEDSSGNLLGRNSFEVRVCACPGRDRRTEEENLRKKGEPHHELPPGSTKRALPNNTSSSPQPKKKPLDGEYFTLQIRGRERFEMFRELNEALELKDAQAGKEPGGSRAHSSHLKSKKGQSTSRHKKLMFKTEGPDSD'],
332
+ ['BLAT_ECOLX' ,'Prokaryote', 'MSIQHFRVALIPFFAAFCLPVFAHPETLVKVKDAEDQLGARVGYIELDLNSGKILESFRPEERFPMMSTFKVLLCGAVLSRVDAGQEQLGRRIHYSQNDLVEYSPVTEKHLTDGMTVRELCSAAITMSDNTAANLLLTTIGGPKELTAFLHNMGDHVTRLDRWEPELNEAIPNDERDTTMPAAMATTLRKLLTGELLTLASRQQLIDWMEADKVAGPLLRSALPAGWFIADKSGAGERGSRGIIAALGPDGKPSRIVVIYTTGSQATMDERNRQIAEIGASLIKHW'],
333
+ ['BRCA1_HUMAN' ,'Human', 'MDLSALRVEEVQNVINAMQKILECPICLELIKEPVSTKCDHIFCKFCMLKLLNQKKGPSQCPLCKNDITKRSLQESTRFSQLVEELLKIICAFQLDTGLEYANSYNFAKKENNSPEHLKDEVSIIQSMGYRNRAKRLLQSEPENPSLQETSLSVQLSNLGTVRTLRTKQRIQPQKTSVYIELGSDSSEDTVNKATYCSVGDQELLQITPQGTRDEISLDSAKKAACEFSETDVTNTEHHQPSNNDLNTTEKRAAERHPEKYQGSSVSNLHVEPCGTNTHASSLQHENSSLLLTKDRMNVEKAEFCNKSKQPGLARSQHNRWAGSKETCNDRRTPSTEKKVDLNADPLCERKEWNKQKLPCSENPRDTEDVPWITLNSSIQKVNEWFSRSDELLGSDDSHDGESESNAKVADVLDVLNEVDEYSGSSEKIDLLASDPHEALICKSERVHSKSVESNIEDKIFGKTYRKKASLPNLSHVTENLIIGAFVTEPQIIQERPLTNKLKRKRRPTSGLHPEDFIKKADLAVQKTPEMINQGTNQTEQNGQVMNITNSGHENKTKGDSIQNEKNPNPIESLEKESAFKTKAEPISSSISNMELELNIHNSKAPKKNRLRRKSSTRHIHALELVVSRNLSPPNCTELQIDSCSSSEEIKKKKYNQMPVRHSRNLQLMEGKEPATGAKKSNKPNEQTSKRHDSDTFPELKLTNAPGSFTKCSNTSELKEFVNPSLPREEKEEKLETVKVSNNAEDPKDLMLSGERVLQTERSVESSSISLVPGTDYGTQESISLLEVSTLGKAKTEPNKCVSQCAAFENPKGLIHGCSKDNRNDTEGFKYPLGHEVNHSRETSIEMEESELDAQYLQNTFKVSKRQSFAPFSNPGNAEEECATFSAHSGSLKKQSPKVTFECEQKEENQGKNESNIKPVQTVNITAGFPVVGQKDKPVDNAKCSIKGGSRFCLSSQFRGNETGLITPNKHGLLQNPYRIPPLFPIKSFVKTKCKKNLLEENFEEHSMSPEREMGNENIPSTVSTISRNNIRENVFKEASSSNINEVGSSTNEVGSSINEIGSSDENIQAELGRNRGPKLNAMLRLGVLQPEVYKQSLPGSNCKHPEIKKQEYEEVVQTVNTDFSPYLISDNLEQPMGSSHASQVCSETPDDLLDDGEIKEDTSFAENDIKESSAVFSKSVQKGELSRSPSPFTHTHLAQGYRRGAKKLESSEENLSSEDEELPCFQHLLFGKVNNIPSQSTRHSTVATECLSKNTEENLLSLKNSLNDCSNQVILAKASQEHHLSEETKCSASLFSSQCSELEDLTANTNTQDPFLIGSSKQMRHQSESQGVGLSDKELVSDDEERGTGLEENNQEEQSMDSNLGEAASGCESETSVSEDCSGLSSQSDILTTQQRDTMQHNLIKLQQEMAELEAVLEQHGSQPSNSYPSIISDSSALEDLRNPEQSTSEKAVLTSQKSSEYPISQNPEGLSADKFEVSADSSTSKNKEPGVERSSPSKCPSLDDRWYMHSCSGSLQNRNYPSQEELIKVVDVEEQQLEESGPHDLTETSYLPRQDLEGTPYLESGISLFSDDPESDPSEDRAPESARVGNIPSSTSALKVPQLKVAESAQSPAAAHTTDTAGYNAMEESVSREKPELTASTERVNKRMSMVVSGLTPEEFMLVYKFARKHHITLTNLITEETTHVVMKTDAEFVCERTLKYFLGIAGGKWVVSYFWVTQSIKERKMLNEHDFEVRGDVVNGRNHQGPKRARESQDRKIFRGLEICCYGPFTNMPTDQLEWMVQLCGASVVKELSSFTLGTGVHPIVVVQPDAWTEDNGFHAIGQMCEAPVVTREWVLDSVALYQCQELDTYLIPQIPHSHY'],
334
+ ['CALM1_HUMAN' ,'Human', 'MADQLTEEQIAEFKEAFSLFDKDGDGTITTKELGTVMRSLGQNPTEAELQDMINEVDADGNGTIDFPEFLTMMARKMKDTDSEEEIREAFRVFDKDGNGYISAAELRHVMTNLGEKLTDEEVDEMIREADIDGDGQVNYEEFVQMMTAK'],
335
+ ['CCDB_ECOLI' ,'Prokaryote', 'MQFKVYTYKRESRYRLFVDVQSDIIDTPGRRMVIPLASARLLSDKVSRELYPVVHIGDESWRMMTTDMASVPVSVIGEEVADLSHRENDIKNAINLMFWGI'],
336
+ ['GFP_AEQVI' ,'Other eukaryote', 'MSKGEELFTGVVPILVELDGDVNGHKFSVSGEGEGDATYGKLTLKFICTTGKLPVPWPTLVTTLSYGVQCFSRYPDHMKQHDFFKSAMPEGYVQERTIFFKDDGNYKTRAEVKFEGDTLVNRIELKGIDFKEDGNILGHKLEYNYNSHNVYIMADKQKNGIKVNFKIRHNIEDGSVQLADHYQQNTPIGDGPVLLPDNHYLSTQSALSKDPNEKRDHMVLLEFVTAAGITHGMDELYK'],
337
+ ['GRB2_HUMAN' ,'Human', 'MEAIAKYDFKATADDELSFKRGDILKVLNEECDQNWYKAELNGKDGFIPKNYIEMKPHPWFFGKIPRAKAEEMLSKQRHDGAFLIRESESAPGDFSLSVKFGNDVQHFKVLRDGAGKYFLWVVKFNSLNELVDYHRSTSVSRNQQIFLRDIEQVPQQPTYVQALFDFDPQEDGELGFRRGDFIHVMDNSDPNWWKGACHGQTGMFPRNYVTPVNRNV'],
338
+ ],
339
+ )
340
+
341
+ gr.Markdown("<br>")
342
+ gr.Markdown("# Fitness predictions for all single amino acid substitutions in mutation range")
343
+ gr.Markdown("Inference may take a few seconds for short proteins & mutation ranges to several minutes for longer ones")
344
+ output_image = gr.Gallery(label="Fitness predictions for all single amino acid substitutions in mutation range") #Using Gallery to break down large scoring matrices into smaller images
345
+
346
+ output_recommendations = gr.Textbox(label="Mutation recommendations")
347
+
348
+ with gr.Row():
349
+ gr.Markdown("## Download CSV Files")
350
+ output_csv_files = gr.File(label="Download CSV files with fitness scores", file_count="multiple", interactive=False)
351
+
352
+ clear_button.click(
353
+ inputs = [protein_sequence_input,mutation_range_start,mutation_range_end],
354
+ outputs = [protein_sequence_input,mutation_range_start,mutation_range_end],
355
+ fn=clear_inputs
356
+ )
357
+ run_button.click(
358
+ fn=score_and_create_matrix_all_singles,
359
+ inputs=[protein_sequence_input,mutation_range_start,mutation_range_end,model_size_selection,scoring_mirror,batch_size_inference],
360
+ outputs=[output_image,output_recommendations,output_csv_files],
361
+ )
362
+
363
+ gr.Markdown("# Mutate the starting protein sequence")
364
+ with gr.Row():
365
+ mutation_triplet = gr.Textbox(lines=1,label="Selected mutation", placeholder = "Input the mutation triplet for the selected mutation (eg., M1A)")
366
+ mutate_button = gr.Button(value="Apply mutation to starting protein", variant="primary")
367
+ mutated_protein_sequence = gr.Textbox(lines=1,label="Mutated protein sequence")
368
+ mutate_button.click(
369
+ fn = get_mutated_protein,
370
+ inputs = [protein_sequence_input,mutation_triplet],
371
+ outputs = mutated_protein_sequence
372
+ )
373
+
374
+ gr.Markdown("<p>You may now use the output mutated sequence above as the starting sequence for another round of in silico directed evolution.</p>")
375
+ gr.Markdown("For more information about the Tranception model, please refer to our paper below:")
376
+ gr.Markdown("<p><b>Tranception: Protein Fitness Prediction with Autoregressive Transformers and Inference-time Retrieval</b><br>Pascal Notin, Mafalda Dias, Jonathan Frazer, Javier Marchena-Hurtado, Aidan N. Gomez, Debora S. Marks<sup>*</sup>, Yarin Gal<sup>*</sup><br><sup>* equal senior authorship</sup></p>")
377
+ gr.Markdown("Links: <a href='https://proceedings.mlr.press/v162/notin22a.html' target='_blank'>Paper</a> <a href='https://github.com/OATML-Markslab/Tranception' target='_blank'>Code</a> <a href='https://sites.google.com/view/proteingym/substitutions' target='_blank'>ProteinGym</a>")
378
+
379
if __name__ == "__main__":
    # queue() is called before launch(); presumably so that long-running
    # predictions go through Gradio's request queue. Confirm against the
    # Gradio version in use.
    tranception_design.queue()
    tranception_design.launch()