Spaces:
Runtime error
Runtime error
| import streamlit as st | |
| import torch | |
| import esm | |
| import requests | |
| import matplotlib.pyplot as plt | |
| from clickhouse_connect import get_client | |
| import random | |
| from collections import Counter | |
| from tqdm import tqdm | |
| from statistics import mean | |
| import biotite.structure.io as bsio | |
| import torch | |
| import matplotlib.pyplot as plt | |
| import numpy as np | |
| import pandas as pd | |
| import seaborn as sns | |
| from stmol import * | |
| import py3Dmol | |
| # from streamlit_3Dmol import component_3dmol | |
| import scipy | |
| from sklearn.model_selection import GridSearchCV, train_test_split | |
| from sklearn.decomposition import PCA | |
| from sklearn.neighbors import KNeighborsClassifier, KNeighborsRegressor | |
| from sklearn.ensemble import RandomForestClassifier, RandomForestRegressor | |
| from sklearn.linear_model import LogisticRegression, SGDRegressor | |
| from sklearn.pipeline import Pipeline | |
| from streamlit.components.v1 import html | |
def init_esm():
    """Load the pretrained ESM MSA Transformer (esm_msa1b_t12_100M_UR50S).

    Returns:
        A ``(model, alphabet)`` pair; the model is switched to eval mode.
    """
    model, alphabet = esm.pretrained.esm_msa1b_t12_100M_UR50S()
    # ``eval()`` switches the module to inference mode in place and returns it.
    return model.eval(), alphabet
def init_db():
    """Initialize the database connection.

    Reads ``DB_URL`` (``<scheme>://<host>:<port>``), ``USER`` and ``PASSWD``
    from ``st.secrets`` and opens a client with ``get_client``.

    Returns:
        meta_field: An empty dict (callers store it in ``st.session_state.meta``).
        client: Database connection object returned by ``get_client``.

    Raises:
        ValueError: If ``DB_URL`` has no scheme or no host.
    """
    # Standard-library URL parsing; the third-party ``parse`` helper used
    # here before was never imported, so this function raised NameError.
    from urllib.parse import urlparse

    url = urlparse(st.secrets["DB_URL"])
    if not url.scheme or not url.hostname:
        raise ValueError("DB_URL must look like '<scheme>://<host>:<port>'")
    client = get_client(
        host=url.hostname,
        # 0 is get_client's own default for "no explicit port".
        port=url.port or 0,
        user=st.secrets["USER"],
        password=st.secrets["PASSWD"],
        interface=url.scheme,
    )
    meta_field = {}
    return meta_field, client
def perdict_contact_visualization(seq, model, batch_converter):
    """Plot the model's predicted residue-residue contact map for ``seq``.

    Args:
        seq: Protein sequence as a string of one-letter residue codes.
        model: ESM model that accepts ``return_contacts=True``.
        batch_converter: Converter from the model alphabet's
            ``get_batch_converter()``.

    Returns:
        A matplotlib ``Figure`` showing the first ``len(seq) x len(seq)``
        entries of the contact map.
    """
    # Build the figure with the object-oriented API instead of pyplot: pyplot
    # keeps a global reference to every figure it creates, and this function
    # runs on every Streamlit rerun.
    from matplotlib.figure import Figure

    _, _, batch_tokens = batch_converter([("protein1", seq)])
    # Only the contact head is needed, so no hidden representations are requested.
    with torch.no_grad():
        results = model(batch_tokens, return_contacts=True)
    contacts = results["contacts"][0]  # the batch holds a single sequence

    fig = Figure()
    ax = fig.subplots()
    ax.matshow(contacts[: len(seq), : len(seq)])
    return fig
def visualize_3D_Coordinates(coords):
    """Draw a 3D polyline through ``coords``, coloured from start to end.

    Args:
        coords: Sequence of points; only items 0, 1 and 2 (x, y, z) of each
            point are used.

    Returns:
        A matplotlib ``Figure`` holding one 3D axes.
    """
    # Object-oriented Figure instead of plt.figure so repeated reruns do not
    # accumulate figures in pyplot's global registry.
    from matplotlib.figure import Figure

    points = [(p[0], p[1], p[2]) for p in coords]
    n_points = len(points)

    fig = Figure(figsize=(10, 10))
    ax = fig.add_subplot(111, projection='3d')
    ax.set_title('3D coordinates of $C_{b}$ backbone structure')
    # One segment per consecutive pair, coloured by its position along the chain.
    for i, (start, end) in enumerate(zip(points, points[1:])):
        xs, ys, zs = zip(start, end)
        ax.plot(xs, ys, zs, color=plt.cm.viridis(i / n_points), marker='o')
    return fig
def render_mol(pdb):
    """Show a spinning, spectrum-coloured cartoon of a PDB string via stmol.

    Args:
        pdb: Structure contents in PDB text format.
    """
    viewer = py3Dmol.view()
    viewer.addModel(pdb, 'pdb')
    viewer.setStyle({'cartoon': {'color': 'spectrum'}})
    viewer.setBackgroundColor('white')
    # Fit the model, then zoom(factor=2, animation duration=800), then rotate.
    viewer.zoomTo()
    viewer.zoom(2, 800)
    viewer.spin(True)
    showmol(viewer, height=500, width=800)
def esm_search(model, sequnce, batch_converter, top_k=5):
    """Return stored sequences whose embedding is nearest to ``sequnce``.

    The query vector is element ``[0, 0, 0]`` of the layer-12 representation
    (for the MSA Transformer presumably batch 0, row 0, the BOS token - TODO
    confirm). It is compared with the ``representations`` column of
    ``default.esm_protein_indexer_768`` using the database ``distance``
    function.

    Args:
        model: ESM model with a layer 12 (``repr_layers=[12]`` is requested).
        sequnce: Query protein sequence.
        batch_converter: Converter from the model's alphabet.
        top_k: Maximum number of distinct sequences to return.

    Returns:
        Up to ``top_k`` distinct sequences, nearest first.
    """
    _, _, batch_tokens = batch_converter([("protein1", sequnce)])
    with torch.no_grad():
        results = model(batch_tokens, repr_layers=[12])
    # Index first so only the single query vector is converted to Python floats.
    query_vector = results["representations"][12][0, 0, 0].tolist()
    rows = st.session_state.client.query(
        "SELECT seq, distance(representations, " + str(query_vector) + ')'
        + "as dist FROM default.esm_protein_indexer_768 ORDER BY dist LIMIT 500"
    )
    # dict.fromkeys de-duplicates while keeping the ORDER BY dist ranking
    # (a set would return the sequences in arbitrary order).
    unique_seqs = list(dict.fromkeys(row['seq'] for row in rows.named_results()))
    return unique_seqs[:top_k]
def show_protein_structure(sequence):
    """Fold ``sequence`` with the ESM Atlas API and render the returned PDB.

    The PDB text is passed directly to ``render_mol``. It is not written to
    a fixed 'predicted.pdb' path, because concurrent sessions could
    overwrite that file. If the request fails or returns an HTTP error
    status, an error is shown in the app and nothing is rendered.

    Args:
        sequence: Protein sequence to fold.
    """
    headers = {
        'Content-Type': 'application/x-www-form-urlencoded',
    }
    try:
        response = requests.post(
            'https://api.esmatlas.com/foldSequence/v1/pdb/',
            headers=headers,
            data=sequence,
            # Without a timeout a stalled API call blocks the rerun indefinitely.
            timeout=60,
        )
        response.raise_for_status()
    except requests.RequestException as err:
        st.error(f'Structure prediction failed: {err}')
        return
    render_mol(response.content.decode('utf-8'))
def KNN_search(sequence):
    """Predict activity as the mean activity of the 10 nearest stored proteins.

    Embeds ``sequence`` with ESM-2 (esm2_t33_650M_UR50D). The query vector is
    element ``[0, 0]`` of the layer-33 representation. The function averages
    the ``activity`` column of the 10 rows of ``default.esm_protein_indexer``
    with the smallest ``distance`` to that vector.

    Args:
        sequence: Query protein sequence.

    Returns:
        The mean ``activity`` of the nearest neighbours.

    Raises:
        ValueError: If the database returns no rows.
    """
    # NOTE(review): the 650M-parameter model is loaded on every call; consider
    # caching it if per-session memory allows.
    model, alphabet = esm.pretrained.esm2_t33_650M_UR50D()
    batch_converter = alphabet.get_batch_converter()
    model.eval()
    _, _, batch_tokens = batch_converter([("protein1", sequence)])
    with torch.no_grad():
        results = model(batch_tokens, repr_layers=[33])
    query_vector = results["representations"][33][0, 0].tolist()
    rows = st.session_state.client.query(
        "SELECT activity, distance(representations, " + str(query_vector) + ')'
        + "as dist FROM default.esm_protein_indexer ORDER BY dist LIMIT 10"
    )
    activities = [row['activity'] for row in rows.named_results()]
    if not activities:
        raise ValueError("no neighbours returned from default.esm_protein_indexer")
    return sum(activities) / len(activities)
def train_test_split_PCA(
    dataset,
    *,
    fasta_path='/root/xuying_experiments/esm-main/P62593.fasta',
    emb_path='/root/xuying_experiments/esm-main/P62593_reprs',
    layer=34,
    train_size=0.8,
    random_state=42,
):
    """Load precomputed mean embeddings and split them into train/test sets.

    Every FASTA header must end in ``|<float>``; that float is the label.
    The features are ``mean_representations[layer]`` loaded from
    ``<emb_path>/<header>.pt``.

    Args:
        dataset: Unused; kept for backward compatibility.
        fasta_path: FASTA file listing the variants.
        emb_path: Directory holding one ``<header>.pt`` file per FASTA record.
        layer: Key into each file's ``mean_representations`` mapping.
        train_size: Fraction of samples used for training.
        random_state: Seed passed to ``train_test_split``.

    Returns:
        ``(Xs_train, Xs_test, ys_train, ys_test)`` from ``train_test_split``.

    Raises:
        ValueError: If the FASTA file contains no records.
    """
    ys = []
    Xs = []
    for header, _seq in esm.data.read_fasta(fasta_path):
        ys.append(float(header.split('|')[-1]))
        embs = torch.load(f'{emb_path}/{header}.pt')
        Xs.append(embs['mean_representations'][layer])
    if not Xs:
        raise ValueError(f'no records found in {fasta_path}')
    Xs = torch.stack(Xs, dim=0).numpy()
    return train_test_split(Xs, ys, train_size=train_size, random_state=random_state)
def PCA_visual(Xs_train, ys_train=None):
    """Scatter the training embeddings on their first two principal components.

    Args:
        Xs_train: 2D array of training embeddings (samples x features).
        ys_train: Optional per-sample values used to colour the points. If
            omitted, the points use the default colour and no colour bar is
            drawn. (The previous body read ``ys_train`` from a global that
            does not exist.)

    Returns:
        A matplotlib ``Figure``.
    """
    # Object-oriented Figure instead of pyplot, so figures are not kept alive
    # in pyplot's global registry across Streamlit reruns.
    from matplotlib.figure import Figure

    num_pca_components = 60
    pca = PCA(num_pca_components)
    Xs_train_pca = pca.fit_transform(Xs_train)

    fig = Figure(figsize=(4, 4))
    ax = fig.subplots()
    ax.set_title('Visualize Embeddings')
    sc = ax.scatter(Xs_train_pca[:, 0], Xs_train_pca[:, 1], c=ys_train, marker='.')
    ax.set_xlabel('PCA first principal component')
    ax.set_ylabel('PCA second principal component')
    if ys_train is not None:
        fig.colorbar(sc, ax=ax, label='Variant Effect')
    return fig
def KNN_trainings(Xs_train, Xs_test, ys_train, ys_test):
    """Grid-search a PCA + k-nearest-neighbours regression pipeline.

    Args:
        Xs_train: Training features.
        Xs_test: Unused; kept for interface compatibility.
        ys_train: Training targets.
        ys_test: Unused; kept for interface compatibility.

    Returns:
        A DataFrame with the five best parameter combinations, sorted by
        ``rank_test_score``. Columns: ``param_model``, ``params``,
        ``param_model__algorithm``, ``mean_test_score`` (R^2) and
        ``rank_test_score``.
    """
    num_pca_components = 60
    knn_grid = [
        {
            'model': [KNeighborsRegressor()],
            'model__n_neighbors': [5, 10],
            'model__weights': ['uniform', 'distance'],
            'model__algorithm': ['ball_tree', 'kd_tree', 'brute'],
            'model__leaf_size': [15, 30],
            'model__p': [1, 2],
        }
    ]
    # scikit-learn documents ``steps`` as a list of (name, estimator) tuples.
    pipe = Pipeline(
        steps=[
            ('pca', PCA(num_pca_components)),
            ('model', KNeighborsRegressor()),
        ]
    )
    grid = GridSearchCV(
        estimator=pipe,
        param_grid=knn_grid,
        scoring='r2',
        verbose=1,
        n_jobs=-1,  # use all available cores
    )
    grid.fit(Xs_train, ys_train)
    cv_results = pd.DataFrame.from_dict(grid.cv_results_)
    top5 = cv_results.sort_values('rank_test_score')[:5]
    return top5[['param_model', 'params', 'param_model__algorithm', 'mean_test_score', 'rank_test_score']]
# Load the Roboto web font for the page.
st.markdown("""
<link
rel="stylesheet"
href="https://fonts.googleapis.com/css?family=Roboto:300,400,500,700&display=swap"
/>
""", unsafe_allow_html=True)

# Intro text shown in an info box at the top of the app. The lines are not
# indented, so Markdown does not turn them into a code block; the trailing
# backslashes join sentences that belong to the same paragraph.
messages = [
    """
Evolutionary-scale prediction of atomic-level protein structure

ESM is a high-capacity Transformer trained with protein sequences as input. \
After training, information about the secondary and tertiary structure, \
function and homology of a protein is contained in the feature \
representation output by the model. \
Check out https://esmatlas.com/ for more information.

We have 120k protein features stored in our database.
The app uses MyScale to store and query protein sequences
using vector search.
"""
]
def init_random_query(dims=768):
    """Return a random query vector and an independent copy of it.

    Args:
        dims: Vector length. The default of 768 matches the ``_768`` suffix
            of the ``esm_protein_indexer_768`` table. (The previous body read
            a global ``DIMS`` that is never defined.)

    Returns:
        ``(xq, xq_copy)``: two equal but distinct lists of ``dims`` floats
        drawn uniformly from [0, 1).
    """
    xq = np.random.rand(dims).tolist()
    return xq, xq.copy()
# Open the database connection once per session instead of on every rerun.
if 'client' not in st.session_state:
    with st.spinner("Connecting DB..."):
        st.session_state.meta, st.session_state.client = init_db()

with st.spinner("Loading Models..."):
    # Load the ESM MSA Transformer until the UI sets the 'xq' flag for this session.
    if 'xq' not in st.session_state:
        st.session_state.model, st.session_state.alphabet = init_esm()
        batch_converter = st.session_state.alphabet.get_batch_converter()
        st.session_state['batch'] = batch_converter
        # Both session keys refer to the same converter object.
        st.session_state.batch_converter = batch_converter
        st.session_state.query_num = 0
# One-click example inputs: button label -> sequence.
EXAMPLE_SEQUENCES = {
    'Cas9 Enzyme': 'GSGHMDKKYSIGLAIGTNSVGWAVITDEYKVPSKKFKVLGNTDRHSIKKNLIGALLFDSGETAEATRLKRTARRRYTRRKNRILYLQEIFSNEMAKV',
    'PETase': 'MGSSHHHHHHSSGLVPRGSHMRGPNPTAASLEASAGPFTVRSFTVSRPSGYGAGTVYYPTNAGGTVGAIAIVPGYTARQSSIKWWGPRLASHGFVVITIDTNSTLDQPSSRSSQQMAALRQVASLNGTSSSPIYGKVDTARMGVMGWSMGGGGSLISAANNPSLKAAAPQAPWDSSTNFSSVTVPTLIFACENDSIAPVNSSALPIYDSMSRNAKQFLEINGGSHSCANSGNSNQALIGKKGVAWMKRFMDNDTRYSTFACENPNSTRVSDFRTANCSLEDPAANKARKEAELAAATAEQ',
}


def _sequence_input(label):
    """Render a sequence text box plus example buttons; return the chosen sequence."""
    sequence = st.text_input(label, '')
    st.write('Try an example:')
    # Render every button before checking clicks, so no button disappears on
    # the run in which another one was clicked.
    clicked = [name for name in EXAMPLE_SEQUENCES if st.button(name)]
    if clicked:
        sequence = EXAMPLE_SEQUENCES[clicked[0]]
    return sequence


# 'xq' is the flag the model-loading block checks; setting it stops the
# model from being reloaded on later reruns of this session.
st.session_state['xq'] = st.session_state.model

if st.session_state.get('query_num', 0) < len(messages):
    msg = messages[0]
else:
    msg = messages[-1]

function_list = (
    'self-contact prediction',
    'search the database for similar proteins',
    'activity prediction with similar proteins',
    'PDB viewer',
)

with st.container():
    st.title("Evolutionary Scale Modeling")
    st.info(msg)
    option = st.selectbox('Application options', function_list)
    st.session_state.db_name_ref = 'default.esm_protein'

    if option == function_list[0]:
        sequence = _sequence_input('protein sequence(Capital letters only)')
        if sequence:
            st.write('you have entered: ', sequence)
            st.pyplot(perdict_contact_visualization(
                sequence, st.session_state.model, st.session_state.batch_converter))
            expander = st.expander("See explanation")
            expander.markdown(
                """<span style="word-wrap:break-word;">Contact prediction is based on a logistic regression over the model's attention maps. This methodology is based on ICLR 2021 paper, Transformer protein language models are unsupervised structure learners. (Rao et al. 2020). The MSA Transformer (ESM-MSA-1) takes a multiple sequence alignment (MSA) as input, and uses the tied row self-attention maps in the same way.</span>
""", unsafe_allow_html=True)

    elif option == function_list[1]:
        sequence = _sequence_input('protein sequence(Capital letters only)')
        if sequence:
            st.write('you have entered: ', sequence)
            hits = esm_search(st.session_state.model, sequence,
                              st.session_state.batch_converter, top_k=10)[:5]
            if not hits:
                st.warning('No similar proteins found.')
            else:
                st.text(f'search result (top {len(hits)}): ')
                tabs = st.tabs([str(i) for i in range(1, len(hits) + 1)])
                for tab, hit in zip(tabs, hits):
                    with tab:
                        st.write(hit)
                        show_protein_structure(hit)

    elif option == function_list[2]:
        st.markdown('we predict the biological activity of mutations of a protein, using fixed embeddings from ESM.')
        sequence = _sequence_input('protein sequence')
        if sequence:
            st.write('you have entered: ', sequence)
            res_knn = KNN_search(sequence)
            st.subheader('KNN predictor result')
            st.markdown("Activity prediction: " + str(res_knn))

    elif option == function_list[3]:
        id_PDB = st.text_input('enter PDB ID', '')
        residues_marker = st.text_input('residues class', '')
        st.write('Try an example:')
        if st.button('PDB ID: 1A2C / residues class: ALA'):
            id_PDB = '1A2C'
            residues_marker = 'ALA'
        st.subheader('PDB viewer')
        # Only fetch and render once a PDB ID has been given.
        if id_PDB:
            if residues_marker:
                showmol(render_pdb_resn(viewer=render_pdb(id=id_PDB), resn_lst=[residues_marker]))
            else:
                showmol(render_pdb(id=id_PDB))
        expander = st.expander("See explanation")
        expander.markdown("""
A PDB ID is a unique 4-character code for each entry in the Protein Data Bank. The first character must be a number between 1 and 9, and the remaining three characters can be letters or numbers.
see https://www.rcsb.org/ for more information.
""")