#!/usr/bin/env python3
"""
Tranception Design App - Hugging Face Spaces Version
"""
import atexit
import gc
import os
import shutil
import subprocess
import sys
import tempfile
import threading
import uuid
import zipfile

import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import torch
import transformers
from huggingface_hub import hf_hub_download
from transformers import PreTrainedTokenizerFast
# Make modules that live next to this script importable.
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))

# Fetch the `tranception` package on first start: clone the upstream
# repository into a scratch directory, keep only its `tranception/`
# sub-package (placed in the current working directory) and discard the rest.
if not os.path.exists("tranception"):
    print("Downloading Tranception repository...")
    # A checkout left behind by an interrupted earlier run would make
    # `git clone` fail, so start from a clean slate.
    shutil.rmtree("temp_tranception", ignore_errors=True)
    try:
        # check=True surfaces a failed clone immediately instead of letting
        # the subsequent move fail with a confusing "file not found".
        subprocess.run(
            [
                "git",
                "clone",
                "https://github.com/OATML-Markslab/Tranception.git",
                "temp_tranception",
            ],
            check=True,
        )
        # Move the tranception module to the current directory.
        shutil.move("temp_tranception/tranception", "tranception")
    finally:
        # Clean up the scratch checkout whether or not the steps above worked.
        shutil.rmtree("temp_tranception", ignore_errors=True)

import tranception
from tranception import config, model_pytorch
# Resolve the location of a model checkpoint (local directory or Hub id).
def download_model_from_hf(model_name):
    """Return the location a Tranception checkpoint should be loaded from.

    Despite its name, this function performs no download itself; it only
    decides which identifier to hand back to the caller.

    Args:
        model_name: Checkpoint name, e.g. ``"Tranception_Small"``.

    Returns:
        ``"./<model_name>"`` when that path exists locally. Otherwise the
        Hugging Face Hub id ``"PascalNotin/<model_name>"`` for the Small and
        Medium checkpoints, and ``"PascalNotin/Tranception_Medium"`` as a
        fallback for any other name.
    """
    model_path = f"./{model_name}"
    if os.path.exists(model_path):
        return model_path

    print(f"Downloading {model_name} model...")
    # Only the Small and Medium checkpoints are referenced on the Hub here.
    if model_name in ("Tranception_Small", "Tranception_Medium"):
        return f"PascalNotin/{model_name}"

    # Anything else (notably the Large model) falls back to Medium.
    print("Note: Large model needs to be downloaded from the original source.")
    print("Using Medium model as fallback...")
    return "PascalNotin/Tranception_Medium"
# One-letter codes of the 20 amino acids considered as substitution targets.
AA_vocab = "ACDEFGHIKLMNPQRSTVWY"

# Tokenizer loaded from the definition file inside the `tranception` package
# directory; the path is relative to the current working directory.
tokenizer = PreTrainedTokenizerFast(
    tokenizer_file="./tranception/utils/tokenizers/Basic_tokenizer",
    unk_token="[UNK]",
    sep_token="[SEP]",
    pad_token="[PAD]",
    cls_token="[CLS]",
    mask_token="[MASK]",
)
def create_all_single_mutants(sequence, AA_vocab=AA_vocab, mutation_range_start=None, mutation_range_end=None):
    """Enumerate every single-residue substitution inside a position window.

    Args:
        sequence: Starting protein sequence (string of one-letter codes).
        AA_vocab: Iterable of candidate replacement residues.
        mutation_range_start: 1-indexed first position to mutate (inclusive).
            Defaults to 1.
        mutation_range_end: 1-indexed last position to mutate (inclusive).
            Defaults to ``len(sequence)``.

    Returns:
        A ``pandas.DataFrame`` with columns ``mutant`` (e.g. ``"M1A"``:
        original residue, 1-indexed position, new residue) and
        ``mutated_sequence`` (the full sequence with that substitution).
    """
    if mutation_range_start is None:
        mutation_range_start = 1
    if mutation_range_end is None:
        mutation_range_end = len(sequence)

    sequence_list = list(sequence)
    all_single_mutants = {}
    window = sequence[mutation_range_start - 1:mutation_range_end]
    # `position` is the 1-indexed position within the full sequence.
    for position, current_AA in enumerate(window, start=mutation_range_start):
        for mutated_AA in AA_vocab:
            if current_AA == mutated_AA:
                continue  # substituting a residue by itself is not a mutation
            mutated_sequence = sequence_list.copy()
            mutated_sequence[position - 1] = mutated_AA
            all_single_mutants[current_AA + str(position) + mutated_AA] = "".join(mutated_sequence)

    all_single_mutants = pd.DataFrame.from_dict(all_single_mutants, columns=['mutated_sequence'], orient='index')
    all_single_mutants.reset_index(inplace=True)
    all_single_mutants.columns = ['mutant', 'mutated_sequence']
    return all_single_mutants
def create_scoring_matrix_visual(scores, sequence, image_index=0, mutation_range_start=None, mutation_range_end=None, AA_vocab=AA_vocab, annotate=True, fontsize=20, unique_id=None):
    """Write the substitution-score heatmap (PNG) and table (CSV) for a window.

    Args:
        scores: DataFrame with at least the columns ``mutant``, ``position``,
            ``target_AA`` and ``avg_score``.
        sequence: Starting protein sequence the mutants were derived from.
        image_index: Index of this window; becomes part of both file names.
        mutation_range_start: 1-indexed first position shown (inclusive).
            Defaults to 1.
        mutation_range_end: 1-indexed last position shown (inclusive).
            Defaults to ``len(sequence)``.
        AA_vocab: Candidate residues, one heatmap column each.
        annotate: When true, every cell is labelled with mutant and score.
        fontsize: Base font size for annotations, ticks and titles.
        unique_id: Token embedded in the file names; a fresh UUID if omitted.

    Returns:
        Tuple ``(image_path, csv_path)`` of the two files written to the
        current working directory.
    """
    if unique_id is None:
        unique_id = str(uuid.uuid4())
    # The signature advertises None defaults; resolve them instead of crashing.
    if mutation_range_start is None:
        mutation_range_start = 1
    if mutation_range_end is None:
        mutation_range_end = len(sequence)

    positions = range(mutation_range_start, mutation_range_end + 1)
    filtered_scores = scores[scores.position.isin(positions)]
    piv = filtered_scores.pivot(index='position', columns='target_AA', values='avg_score').round(4)

    # Build the mutant -> score lookup once rather than scanning the frame for
    # every cell. setdefault keeps the first occurrence of a mutant.
    score_by_mutant = {}
    for mutant, avg_score in zip(filtered_scores.mutant, filtered_scores.avg_score):
        score_by_mutant.setdefault(mutant, float(avg_score))

    # One entry per (position, target residue) cell, in row-major order.
    # Cells without a score (e.g. the unchanged residue) are reported as 0.0.
    cells = []
    for position in positions:
        original_AA = sequence[position - 1]
        for target_AA in AA_vocab:
            mutant = original_AA + str(position) + target_AA
            cells.append((position, original_AA, target_AA, mutant, score_by_mutant.get(mutant, 0.0)))

    # Save the detailed CSV with mutation info.
    csv_path = 'fitness_scoring_substitution_matrix_{}_{}.csv'.format(unique_id, image_index)
    csv_df = pd.DataFrame(cells, columns=['position', 'original_AA', 'target_AA', 'mutation', 'fitness_score'])
    csv_df.to_csv(csv_path, index=False)

    # Continue with the visualization: one row of the figure per position.
    mutation_range_len = mutation_range_end - mutation_range_start + 1
    fig, ax = plt.subplots(figsize=(50, mutation_range_len))
    ax.tick_params(bottom=True, top=True, left=True, right=True)
    ax.tick_params(labelbottom=True, labeltop=True, labelleft=True, labelright=True)

    # The colour scale is taken from *all* scores (not just this window), so
    # the images of one request share the same scale.
    heatmap_kwargs = dict(
        cmap='RdYlGn',
        linewidths=0.30,
        ax=ax,
        vmin=np.percentile(scores.avg_score, 2),
        vmax=np.percentile(scores.avg_score, 98),
        cbar_kws={'label': 'Log likelihood ratio (mutant / starting sequence)'},
        annot_kws={"size": fontsize},
    )
    if annotate:
        labels = np.asarray(
            ["{} \n {:.4f}".format(mutant, score) for (_, _, _, mutant, score) in cells]
        ).reshape(mutation_range_len, len(AA_vocab))
        heat = sns.heatmap(piv, annot=labels, fmt="", **heatmap_kwargs)
    else:
        heat = sns.heatmap(piv, **heatmap_kwargs)

    # The last axes of the figure is the colour bar.
    heat.figure.axes[-1].yaxis.label.set_size(fontsize=int(fontsize*1.5))
    heat.set_title("Higher predicted scores (green) imply higher protein fitness", fontsize=fontsize*2, pad=40)
    heat.set_ylabel("Sequence position", fontsize=fontsize*2)
    heat.set_xlabel("Amino Acid mutation", fontsize=fontsize*2)

    # Set y-axis labels (positions, with the original residue in brackets).
    yticklabels = [str(pos) + ' (' + sequence[pos - 1] + ')' for pos in positions]
    heat.set_yticklabels(yticklabels, fontsize=fontsize, rotation=0)
    # Set x-axis labels (amino acids) - ensuring correct number.
    heat.set_xticklabels(list(AA_vocab), fontsize=fontsize)

    fig.tight_layout()
    image_path = 'fitness_scoring_substitution_matrix_{}_{}.png'.format(unique_id, image_index)
    fig.savefig(image_path, dpi=100)
    # Close this specific figure rather than whichever one is "current".
    plt.close(fig)
    return image_path, csv_path
def suggest_mutations(scores):
    """Build the textual recommendation shown next to the heatmaps.

    Args:
        scores: DataFrame with at least the columns ``mutant``, ``position``
            and ``avg_score``.

    Returns:
        A string listing the five highest-scoring mutants (with their score
        rounded to 4 decimals) followed by up to five positions with the
        highest mean score, where the mean is taken over the strictly
        positive scores only.
    """
    intro_message = "The following mutations may be sensible options to improve fitness: \n\n"

    # Best mutants (sort once and reuse the result for names and scores).
    top_scores = scores.sort_values(by=['avg_score'], ascending=False).head(5)
    top_mutants_recos = [
        top_mutant + " (" + str(round(top_mutant_fitness, 4)) + ")"
        for top_mutant, top_mutant_fitness in zip(top_scores.mutant, top_scores.avg_score)
    ]
    mutant_recos = "The single mutants with highest predicted fitness are (positive scores indicate fitness increase Vs starting sequence, negative scores indicate fitness decrease):\n {} \n\n".format(", ".join(top_mutants_recos))

    # Best positions: only mutations that improve on the starting sequence count.
    positive_scores = scores[scores.avg_score > 0]
    if len(positive_scores) > 0:
        positive_scores_position_avg = positive_scores.groupby(['position'])['avg_score'].mean().reset_index()
        top_positions = list(positive_scores_position_avg.sort_values(by=['avg_score'], ascending=False).head(5)['position'].astype(str))
        position_recos = "The positions with the highest average fitness increase are (only positions with at least one fitness increase are considered):\n {}".format(", ".join(top_positions))
    else:
        position_recos = "No positions with positive fitness effects found."
    return intro_message + mutant_recos + position_recos
def check_valid_mutant(sequence, mutant, AA_vocab=AA_vocab):
    """Tell whether *mutant* is a well-formed substitution of *sequence*.

    A mutant is written ``<from><position><to>`` (e.g. ``"M1A"``) with a
    1-indexed position.

    Args:
        sequence: Starting protein sequence.
        mutant: Mutation triplet to validate.
        AA_vocab: Allowed replacement residues.

    Returns:
        True when the triplet parses, the position lies inside the sequence,
        the residue at that position equals ``<from>`` and ``<to>`` is in
        *AA_vocab*; False otherwise (including unparsable input).
    """
    try:
        from_AA, position, to_AA = mutant[0], int(mutant[1:-1]), mutant[-1]
    except (IndexError, TypeError, ValueError):
        # Empty / too short / non-numeric position / not a string: return
        # early so the checks below never touch unbound names.
        return False
    if not 0 < position <= len(sequence):
        return False
    return sequence[position - 1] == from_AA and to_AA in AA_vocab
# Global state intended to track active inference threads, plus a lock
# presumably meant to guard it.
# NOTE(review): neither name is referenced in the visible part of this file.
active_inferences = {}
inference_lock = threading.Lock()
def cleanup_old_files(max_age_minutes=30):
    """Delete stale heatmap / CSV outputs from the current working directory.

    Only files matching the naming patterns of the generated heatmap images
    and CSV exports are considered, and only those whose modification time is
    more than *max_age_minutes* in the past are removed.

    Args:
        max_age_minutes: Age threshold in minutes.
    """
    import glob
    import time

    max_age_seconds = max_age_minutes * 60
    current_time = time.time()
    patterns = [
        "fitness_scoring_substitution_matrix_*.png",
        "fitness_scoring_substitution_matrix_*.csv",
        "all_mutations_fitness_scores_*.csv",
    ]
    for pattern in patterns:
        for file_path in glob.glob(pattern):
            try:
                if current_time - os.path.getmtime(file_path) > max_age_seconds:
                    os.remove(file_path)
            except OSError:
                # Best effort: the file may already be gone or not be
                # removable; that must not fail the request being served.
                continue
def get_mutated_protein(sequence, mutant):
    """Apply a single substitution to *sequence*.

    Args:
        sequence: Starting protein sequence.
        mutant: Mutation triplet such as ``"M1A"`` (1-indexed position).

    Returns:
        The mutated sequence, or the message ``"The mutant is not valid"``
        when ``check_valid_mutant`` rejects the triplet.
    """
    if not check_valid_mutant(sequence, mutant):
        return "The mutant is not valid"
    mutated_sequence = list(sequence)
    # mutant[1:-1] is the 1-indexed position, mutant[-1] the new residue.
    mutated_sequence[int(mutant[1:-1]) - 1] = mutant[-1]
    return ''.join(mutated_sequence)
def score_and_create_matrix_all_singles(sequence, mutation_range_start=None, mutation_range_end=None, model_type="Large", scoring_mirror=False, batch_size_inference=20, max_number_positions_per_heatmap=50, num_workers=0, AA_vocab=AA_vocab):
    """Score all single substitutions in a window and render the results.

    Args:
        sequence: Starting protein sequence; surrounding whitespace is
            stripped and the sequence is upper-cased.
        mutation_range_start: 1-indexed first position to mutate (inclusive).
            Defaults to 1.
        mutation_range_end: 1-indexed last position to mutate (inclusive).
            Defaults to the length of the cleaned sequence.
        model_type: ``"Small"``, ``"Medium"`` or ``"Large"``. ``"Large"`` is
            served with the Medium checkpoint.
        scoring_mirror: Forwarded to ``model.score_mutants``.
        batch_size_inference: Inference batch size; capped at 10 on CPU.
        max_number_positions_per_heatmap: Positions per heatmap image.
        num_workers: Forwarded to ``model.score_mutants``.
        AA_vocab: Candidate replacement residues.

    Returns:
        Tuple ``(heatmap_paths, recommendations, csv_paths)``: the list of
        heatmap image paths, the recommendation text, and the list of CSV
        paths (one per heatmap plus one covering all mutations).

    Raises:
        ValueError: On an empty sequence, an invalid mutation range or an
            unknown ``model_type``.
    """
    # Clean up old files periodically.
    cleanup_old_files()
    # Unique ID so concurrent requests do not overwrite each other's files.
    unique_id = str(uuid.uuid4())

    # Clean the sequence first so the default end position and the range
    # checks refer to the sequence that is actually scored.
    sequence = (sequence or "").strip().upper()
    if not sequence:
        raise ValueError("no sequence entered")
    if mutation_range_start is None:
        mutation_range_start = 1
    if mutation_range_end is None:
        mutation_range_end = len(sequence)
    mutation_range_start = int(mutation_range_start)
    mutation_range_end = int(mutation_range_end)

    # Explicit checks (not asserts, which vanish under `python -O`).
    if mutation_range_start < 1 or mutation_range_start > mutation_range_end:
        raise ValueError("mutation range is invalid")
    if mutation_range_end > len(sequence):
        raise ValueError(f"End position ({mutation_range_end}) exceeds sequence length ({len(sequence)})")

    # Load model with HF Space compatibility.
    checkpoint_by_model_type = {
        "Small": "Tranception_Small",
        "Medium": "Tranception_Medium",
        # Large is served with the Medium checkpoint due to memory constraints.
        "Large": "Tranception_Medium",
    }
    if model_type not in checkpoint_by_model_type:
        raise ValueError(f"Unknown model type: {model_type!r}")
    if model_type == "Large":
        print("Note: Large model requires significant memory. Using Medium model for HF Spaces deployment.")
    model_path = download_model_from_hf(checkpoint_by_model_type[model_type])
    model = tranception.model_pytorch.TranceptionLMHeadModel.from_pretrained(pretrained_model_name_or_path=model_path)

    try:
        # Device selection.
        if torch.cuda.is_available():
            device = torch.device("cuda")
            print("Inference will take place on NVIDIA GPU")
        else:
            device = torch.device("cpu")
            print("Inference will take place on CPU")
            # Reduce batch size for CPU inference.
            batch_size_inference = min(batch_size_inference, 10)
        model.to(device)
        model.eval()
        model.config.tokenizer = tokenizer

        all_single_mutants = create_all_single_mutants(sequence, AA_vocab, mutation_range_start, mutation_range_end)
        with torch.no_grad():
            scores = model.score_mutants(
                DMS_data=all_single_mutants,
                target_seq=sequence,
                scoring_mirror=scoring_mirror,
                batch_size_inference=batch_size_inference,
                num_workers=num_workers,
                indel_mode=False,
            )
    finally:
        # Release the model as soon as scoring is over, even if it failed.
        del model
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    # Re-attach the mutant names and derive position / target residue.
    scores = pd.merge(scores, all_single_mutants, on="mutated_sequence", how="left")
    scores["position"] = scores["mutant"].map(lambda x: int(x[1:-1]))
    scores["target_AA"] = scores["mutant"].map(lambda x: x[-1])

    # Break the range into windows of at most
    # `max_number_positions_per_heatmap` positions, one heatmap each.
    score_heatmaps = []
    csv_files = []
    window_starts = range(mutation_range_start, mutation_range_end + 1, max_number_positions_per_heatmap)
    for image_index, window_start in enumerate(window_starts):
        window_end = min(mutation_range_end, window_start + max_number_positions_per_heatmap - 1)
        image_path, csv_path = create_scoring_matrix_visual(scores, sequence, image_index, window_start, window_end, AA_vocab, unique_id=unique_id)
        score_heatmaps.append(image_path)
        csv_files.append(csv_path)

    # Also save a comprehensive CSV with all mutations.
    comprehensive_csv_path = 'all_mutations_fitness_scores_{}.csv'.format(unique_id)
    scores_export = scores[['mutant', 'position', 'target_AA', 'avg_score', 'mutated_sequence']].copy()
    scores_export['original_AA'] = scores_export['mutant'].str[0]
    scores_export = scores_export.rename(columns={'avg_score': 'fitness_score'})
    scores_export = scores_export[['position', 'original_AA', 'target_AA', 'mutant', 'fitness_score', 'mutated_sequence']]
    scores_export.to_csv(comprehensive_csv_path, index=False)
    csv_files.append(comprehensive_csv_path)

    return score_heatmaps, suggest_mutations(scores), csv_files
def extract_sequence(protein_id, taxon, sequence):
    """Return *sequence* unchanged; *protein_id* and *taxon* are ignored."""
    return sequence
def clear_inputs(protein_sequence_input, mutation_range_start, mutation_range_end):
    """Return the reset values for the three input fields.

    The current values are accepted but ignored.

    Returns:
        ``("", None, None)``: an empty sequence and unset range bounds.
    """
    return "", None, None
# Create Gradio app
# NOTE(review): the indentation of this section was lost; the nesting of the
# layout containers below is reconstructed from the statement order and should
# be checked against the intended page layout.
tranception_design = gr.Blocks()
with tranception_design:
    gr.Markdown("# In silico directed evolution for protein redesign with Tranception")
    gr.Markdown("## 🧬 Hugging Face Spaces Demo")
    gr.Markdown("Perform in silico directed evolution with Tranception to iteratively improve the fitness of a protein of interest, one mutation at a time. At each step, the Tranception model computes the log likelihood ratios of all possible single amino acid substitution Vs the starting sequence, and outputs a fitness heatmap and recommendations to guide the selection of the mutation to apply.")
    gr.Markdown("**Note**: This demo runs on CPU in Hugging Face Spaces. For faster inference, consider using GPU locally or selecting the Small model.")

    # --- Inputs -----------------------------------------------------------
    with gr.Tabs():
        with gr.TabItem("Input"):
            with gr.Row():
                protein_sequence_input = gr.Textbox(
                    lines=1,
                    label="Protein sequence",
                    placeholder="Input the sequence of amino acids representing the starting protein of interest or select one from the list of examples below. You may enter the full sequence or just a subdomain (providing full context typically leads to better results, but is slower at inference)",
                )
            with gr.Row():
                mutation_range_start = gr.Number(label="Start of mutation window (first position indexed at 1)", value=1, precision=0)
                mutation_range_end = gr.Number(label="End of mutation window (leave empty for full length)", value=10, precision=0)
        with gr.TabItem("Parameters"):
            with gr.Row():
                model_size_selection = gr.Radio(
                    label="Tranception model size (larger models are more accurate but are slower at inference)",
                    choices=["Small", "Medium", "Large"],
                    value="Small",
                )
            with gr.Row():
                scoring_mirror = gr.Checkbox(label="Score protein from both directions (leads to more robust fitness predictions, but doubles inference time)")
            with gr.Row():
                batch_size_inference = gr.Number(label="Model batch size at inference time (reduce for CPU)", value=10, precision=0)
            with gr.Row():
                gr.Markdown("Note: the current version does not leverage retrieval of homologs at inference time to increase fitness prediction performance.")

    with gr.Row():
        clear_button = gr.Button(value="Clear", variant="secondary")
        run_button = gr.Button(value="Predict fitness", variant="primary")

    # Hidden fields: they only exist so each example row (ID, taxon, sequence)
    # has a component to bind its first two columns to.
    protein_ID = gr.Textbox(label="Uniprot ID", visible=False)
    taxon = gr.Textbox(label="Taxon", visible=False)

    examples = gr.Examples(
        inputs=[protein_ID, taxon, protein_sequence_input],
        outputs=[protein_sequence_input],
        fn=extract_sequence,
        examples=[
            ['ADRB2_HUMAN', 'Human', 'MGQPGNGSAFLLAPNGSHAPDHDVTQERDEVWVVGMGIVMSLIVLAIVFGNVLVITAIAKFERLQTVTNYFITSLACADLVMGLAVVPFGAAHILMKMWTFGNFWCEFWTSIDVLCVTASIETLCVIAVDRYFAITSPFKYQSLLTKNKARVIILMVWIVSGLTSFLPIQMHWYRATHQEAINCYANETCCDFFTNQAYAIASSIVSFYVPLVIMVFVYSRVFQEAKRQLQKIDKSEGRFHVQNLSQVEQDGRTGHGLRRSSKFCLKEHKALKTLGIIMGTFTLCWLPFFIVNIVHVIQDNLIRKEVYILLNWIGYVNSGFNPLIYCRSPDFRIAFQELLCLRRSSLKAYGNGYSSNGNTGEQSGYHVEQEKENKLLCEDLPGTEDFVGHQGTVPSDNIDSQGRNCSTNDSLL'],
            ['IF1_ECOLI', 'Prokaryote', 'MAKEDNIEMQGTVLETLPNTMFRVELENGHVVTAHISGKMRKNYIRILTGDKVTVELTPYDLSKGRIVFRSR'],
            ['P53_HUMAN', 'Human', 'MEEPQSDPSVEPPLSQETFSDLWKLLPENNVLSPLPSQAMDDLMLSPDDIEQWFTEDPGPDEAPRMPEAAPRVAPAPAAPTPAAPAPAPSWPLSSSVPSQKTYQGSYGFRLGFLHSGTAKSVTCTYSPALNKMFCQLAKTCPVQLWVDSTPPPGTRVRAMAIYKQSQHMTEVVRRCPHHERCSDSDGLAPPQHLIRVEGNLRVEYLDDRNTFRHSVVVPYEPPEVGSDCTTIHYNYMCNSSCMGGMNRRPILTIITLEDSSGNLLGRNSFEVRVCACPGRDRRTEEENLRKKGEPHHELPPGSTKRALPNNTSSSPQPKKKPLDGEYFTLQIRGRERFEMFRELNEALELKDAQAGKEPGGSRAHSSHLKSKKGQSTSRHKKLMFKTEGPDSD'],
            ['BLAT_ECOLX', 'Prokaryote', 'MSIQHFRVALIPFFAAFCLPVFAHPETLVKVKDAEDQLGARVGYIELDLNSGKILESFRPEERFPMMSTFKVLLCGAVLSRVDAGQEQLGRRIHYSQNDLVEYSPVTEKHLTDGMTVRELCSAAITMSDNTAANLLLTTIGGPKELTAFLHNMGDHVTRLDRWEPELNEAIPNDERDTTMPAAMATTLRKLLTGELLTLASRQQLIDWMEADKVAGPLLRSALPAGWFIADKSGAGERGSRGIIAALGPDGKPSRIVVIYTTGSQATMDERNRQIAEIGASLIKHW'],
            ['BRCA1_HUMAN', 'Human', 'MDLSALRVEEVQNVINAMQKILECPICLELIKEPVSTKCDHIFCKFCMLKLLNQKKGPSQCPLCKNDITKRSLQESTRFSQLVEELLKIICAFQLDTGLEYANSYNFAKKENNSPEHLKDEVSIIQSMGYRNRAKRLLQSEPENPSLQETSLSVQLSNLGTVRTLRTKQRIQPQKTSVYIELGSDSSEDTVNKATYCSVGDQELLQITPQGTRDEISLDSAKKAACEFSETDVTNTEHHQPSNNDLNTTEKRAAERHPEKYQGSSVSNLHVEPCGTNTHASSLQHENSSLLLTKDRMNVEKAEFCNKSKQPGLARSQHNRWAGSKETCNDRRTPSTEKKVDLNADPLCERKEWNKQKLPCSENPRDTEDVPWITLNSSIQKVNEWFSRSDELLGSDDSHDGESESNAKVADVLDVLNEVDEYSGSSEKIDLLASDPHEALICKSERVHSKSVESNIEDKIFGKTYRKKASLPNLSHVTENLIIGAFVTEPQIIQERPLTNKLKRKRRPTSGLHPEDFIKKADLAVQKTPEMINQGTNQTEQNGQVMNITNSGHENKTKGDSIQNEKNPNPIESLEKESAFKTKAEPISSSISNMELELNIHNSKAPKKNRLRRKSSTRHIHALELVVSRNLSPPNCTELQIDSCSSSEEIKKKKYNQMPVRHSRNLQLMEGKEPATGAKKSNKPNEQTSKRHDSDTFPELKLTNAPGSFTKCSNTSELKEFVNPSLPREEKEEKLETVKVSNNAEDPKDLMLSGERVLQTERSVESSSISLVPGTDYGTQESISLLEVSTLGKAKTEPNKCVSQCAAFENPKGLIHGCSKDNRNDTEGFKYPLGHEVNHSRETSIEMEESELDAQYLQNTFKVSKRQSFAPFSNPGNAEEECATFSAHSGSLKKQSPKVTFECEQKEENQGKNESNIKPVQTVNITAGFPVVGQKDKPVDNAKCSIKGGSRFCLSSQFRGNETGLITPNKHGLLQNPYRIPPLFPIKSFVKTKCKKNLLEENFEEHSMSPEREMGNENIPSTVSTISRNNIRENVFKEASSSNINEVGSSTNEVGSSINEIGSSDENIQAELGRNRGPKLNAMLRLGVLQPEVYKQSLPGSNCKHPEIKKQEYEEVVQTVNTDFSPYLISDNLEQPMGSSHASQVCSETPDDLLDDGEIKEDTSFAENDIKESSAVFSKSVQKGELSRSPSPFTHTHLAQGYRRGAKKLESSEENLSSEDEELPCFQHLLFGKVNNIPSQSTRHSTVATECLSKNTEENLLSLKNSLNDCSNQVILAKASQEHHLSEETKCSASLFSSQCSELEDLTANTNTQDPFLIGSSKQMRHQSESQGVGLSDKELVSDDEERGTGLEENNQEEQSMDSNLGEAASGCESETSVSEDCSGLSSQSDILTTQQRDTMQHNLIKLQQEMAELEAVLEQHGSQPSNSYPSIISDSSALEDLRNPEQSTSEKAVLTSQKSSEYPISQNPEGLSADKFEVSADSSTSKNKEPGVERSSPSKCPSLDDRWYMHSCSGSLQNRNYPSQEELIKVVDVEEQQLEESGPHDLTETSYLPRQDLEGTPYLESGISLFSDDPESDPSEDRAPESARVGNIPSSTSALKVPQLKVAESAQSPAAAHTTDTAGYNAMEESVSREKPELTASTERVNKRMSMVVSGLTPEEFMLVYKFARKHHITLTNLITEETTHVVMKTDAEFVCERTLKYFLGIAGGKWVVSYFWVTQSIKERKMLNEHDFEVRGDVVNGRNHQGPKRARESQDRKIFRGLEICCYGPFTNMPTDQLEWMVQLCGASVVKELSSFTLGTGVHPIVVVQPDAWTEDNGFHAIGQMCEAPVVTREWVLDSVALYQCQELDTYLIPQIPHSHY'],
            ['CALM1_HUMAN', 'Human', 'MADQLTEEQIAEFKEAFSLFDKDGDGTITTKELGTVMRSLGQNPTEAELQDMINEVDADGNGTIDFPEFLTMMARKMKDTDSEEEIREAFRVFDKDGNGYISAAELRHVMTNLGEKLTDEEVDEMIREADIDGDGQVNYEEFVQMMTAK'],
            ['CCDB_ECOLI', 'Prokaryote', 'MQFKVYTYKRESRYRLFVDVQSDIIDTPGRRMVIPLASARLLSDKVSRELYPVVHIGDESWRMMTTDMASVPVSVIGEEVADLSHRENDIKNAINLMFWGI'],
            ['GFP_AEQVI', 'Other eukaryote', 'MSKGEELFTGVVPILVELDGDVNGHKFSVSGEGEGDATYGKLTLKFICTTGKLPVPWPTLVTTLSYGVQCFSRYPDHMKQHDFFKSAMPEGYVQERTIFFKDDGNYKTRAEVKFEGDTLVNRIELKGIDFKEDGNILGHKLEYNYNSHNVYIMADKQKNGIKVNFKIRHNIEDGSVQLADHYQQNTPIGDGPVLLPDNHYLSTQSALSKDPNEKRDHMVLLEFVTAAGITHGMDELYK'],
            ['GRB2_HUMAN', 'Human', 'MEAIAKYDFKATADDELSFKRGDILKVLNEECDQNWYKAELNGKDGFIPKNYIEMKPHPWFFGKIPRAKAEEMLSKQRHDGAFLIRESESAPGDFSLSVKFGNDVQHFKVLRDGAGKYFLWVVKFNSLNELVDYHRSTSVSRNQQIFLRDIEQVPQQPTYVQALFDFDPQEDGELGFRRGDFIHVMDNSDPNWWKGACHGQTGMFPRNYVTPVNRNV'],
        ],
    )

    # --- Outputs ----------------------------------------------------------
    gr.Markdown("<br>")
    gr.Markdown("# Fitness predictions for all single amino acid substitutions in mutation range")
    gr.Markdown("Inference may take a few seconds for short proteins & mutation ranges to several minutes for longer ones")
    # Using Gallery to break down large scoring matrices into smaller images.
    output_image = gr.Gallery(label="Fitness predictions for all single amino acid substitutions in mutation range")
    output_recommendations = gr.Textbox(label="Mutation recommendations")
    with gr.Row():
        gr.Markdown("## Download CSV Files")
    output_csv_files = gr.File(label="Download CSV files with fitness scores", file_count="multiple", interactive=False)

    clear_button.click(
        inputs=[protein_sequence_input, mutation_range_start, mutation_range_end],
        outputs=[protein_sequence_input, mutation_range_start, mutation_range_end],
        fn=clear_inputs,
    )
    # The six inputs map positionally onto the first six parameters of
    # score_and_create_matrix_all_singles.
    run_button.click(
        fn=score_and_create_matrix_all_singles,
        inputs=[protein_sequence_input, mutation_range_start, mutation_range_end, model_size_selection, scoring_mirror, batch_size_inference],
        outputs=[output_image, output_recommendations, output_csv_files],
    )

    # --- Apply a chosen mutation -------------------------------------------
    gr.Markdown("# Mutate the starting protein sequence")
    with gr.Row():
        mutation_triplet = gr.Textbox(lines=1, label="Selected mutation", placeholder="Input the mutation triplet for the selected mutation (eg., M1A)")
    mutate_button = gr.Button(value="Apply mutation to starting protein", variant="primary")
    mutated_protein_sequence = gr.Textbox(lines=1, label="Mutated protein sequence")
    mutate_button.click(
        fn=get_mutated_protein,
        inputs=[protein_sequence_input, mutation_triplet],
        outputs=mutated_protein_sequence,
    )

    # --- Footer -----------------------------------------------------------
    gr.Markdown("<p>You may now use the output mutated sequence above as the starting sequence for another round of in silico directed evolution.</p>")
    gr.Markdown("For more information about the Tranception model, please refer to our paper below:")
    gr.Markdown("<p><b>Tranception: Protein Fitness Prediction with Autoregressive Transformers and Inference-time Retrieval</b><br>Pascal Notin, Mafalda Dias, Jonathan Frazer, Javier Marchena-Hurtado, Aidan N. Gomez, Debora S. Marks<sup>*</sup>, Yarin Gal<sup>*</sup><br><sup>* equal senior authorship</sup></p>")
    gr.Markdown("Links: <a href='https://proceedings.mlr.press/v162/notin22a.html' target='_blank'>Paper</a> <a href='https://github.com/OATML-Markslab/Tranception' target='_blank'>Code</a> <a href='https://sites.google.com/view/proteingym/substitutions' target='_blank'>ProteinGym</a>")
if __name__ == "__main__":
    # Configure queue for better resource management.
    tranception_design.queue(
        max_size=10,  # Limit queue size
        status_update_rate="auto",  # Show status updates
        api_open=False,  # Disable API to prevent external requests
    )
    # Launch with appropriate settings for HF Spaces.
    tranception_design.launch(
        max_threads=2,  # Limit concurrent threads
        show_error=True,
        server_name="0.0.0.0",
        server_port=7860,
    )