Spaces:
Running
Running
a96123155
commited on
Commit
·
82d4030
1
Parent(s):
ad079a8
app
Browse files
.DS_Store
CHANGED
Binary files a/.DS_Store and b/.DS_Store differ
|
|
app.py
CHANGED
@@ -1,52 +1,81 @@
|
|
1 |
import streamlit as st
|
2 |
|
3 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
4 |
# Import necessary libraries
|
5 |
-
import
|
6 |
-
import matplotlib
|
7 |
-
import matplotlib.pyplot as plt
|
8 |
import numpy as np
|
9 |
import os
|
10 |
import pandas as pd
|
11 |
-
import pathlib
|
12 |
import random
|
13 |
-
import scanpy as sc
|
14 |
-
import seaborn as sns
|
15 |
import torch
|
16 |
import torch.nn as nn
|
17 |
import torch.nn.functional as F
|
18 |
-
from argparse import Namespace
|
19 |
from collections import Counter, OrderedDict
|
20 |
from copy import deepcopy
|
21 |
from esm import Alphabet, FastaBatchedDataset, ProteinBertModel, pretrained, MSATransformer
|
22 |
from esm.data import *
|
23 |
from esm.model.esm2 import ESM2
|
24 |
-
from sklearn import preprocessing
|
25 |
-
from sklearn.metrics import (confusion_matrix, roc_auc_score, auc,
|
26 |
-
|
27 |
-
|
28 |
-
|
29 |
-
|
30 |
-
|
31 |
-
from sklearn.model_selection import StratifiedKFold
|
32 |
-
from sklearn.utils import class_weight
|
33 |
-
from scipy.stats import spearmanr, pearsonr
|
34 |
from torch import nn
|
35 |
from torch.nn import Linear
|
36 |
from torch.nn.utils.rnn import pad_sequence
|
37 |
from torch.utils.data import Dataset, DataLoader
|
38 |
-
from torch.optim import lr_scheduler
|
39 |
from tqdm import tqdm, trange
|
40 |
|
41 |
# Set global variables
|
42 |
-
matplotlib.rcParams.update({'font.size': 7})
|
43 |
seed = 19961231
|
44 |
random.seed(seed)
|
45 |
np.random.seed(seed)
|
46 |
torch.manual_seed(seed)
|
47 |
-
torch.cuda.manual_seed(seed)
|
48 |
-
torch.backends.cudnn.deterministic = True
|
49 |
-
torch.backends.cudnn.benchmark = False
|
50 |
|
51 |
|
52 |
global idx_to_tok, prefix, epochs, layers, heads, fc_node, dropout_prob, embed_dim, batch_toks, device, repr_layers, evaluation, include, truncate, return_contacts, return_representation, mask_toks_id, finetune
|
@@ -499,36 +528,6 @@ def predict_raw(raw_input):
|
|
499 |
# print(pred)
|
500 |
return res_pd
|
501 |
|
502 |
-
|
503 |
-
st.title("IRES-LM prediction and mutation")
|
504 |
-
|
505 |
-
# Input sequence
|
506 |
-
st.subheader("Input sequence")
|
507 |
-
|
508 |
-
seq = st.text_area("FASTA format only", value="")
|
509 |
-
st.subheader("Upload sequence file")
|
510 |
-
uploaded = st.file_uploader("Sequence file in FASTA format")
|
511 |
-
|
512 |
-
# augments
|
513 |
-
global output_filename, start_nt_position, end_nt_position, mut_by_prob, transform_type, mlm_tok_num, n_mut, n_designs_ep, n_sampling_designs_ep, n_mlm_recovery_sampling, mutate2stronger
|
514 |
-
output_filename = st.text_input("output a .csv file", value='IRES_LM_prediction_mutation')
|
515 |
-
start_nt_position = st.number_input("The start position of the mutation of this sequence, the first position is defined as 0", value=0)
|
516 |
-
end_nt_position = st.number_input("The last position of the mutation of this sequence, the last position is defined as length(sequence)-1 or -1", value=-1)
|
517 |
-
mut_by_prob = st.checkbox("Mutated by predicted Probability or Transformed Probability of the sequence", value=True)
|
518 |
-
transform_type = st.selectbox("Type of probability transformation",
|
519 |
-
['', 'sigmoid', 'logit', 'power_law', 'tanh'],
|
520 |
-
index=2)
|
521 |
-
mlm_tok_num = st.number_input("Number of masked tokens for each sequence per epoch", value=1)
|
522 |
-
n_mut = st.number_input("Maximum number of mutations for each sequence", value=3)
|
523 |
-
n_designs_ep = st.number_input("Number of mutations per epoch", value=10)
|
524 |
-
n_sampling_designs_ep = st.number_input("Number of sampling mutations from n_designs_ep per epoch", value=5)
|
525 |
-
n_mlm_recovery_sampling = st.number_input("Number of MLM recovery samplings (with AGCT recovery)", value=1)
|
526 |
-
mutate2stronger = st.checkbox("Mutate to stronger IRES variant, otherwise mutate to weaker IRES", value=True)
|
527 |
-
|
528 |
-
if not mut_by_prob and transform_type != '':
|
529 |
-
print("--transform_type must be '' when --mut_by_prob is False")
|
530 |
-
transform_type = ''
|
531 |
-
|
532 |
# Run
|
533 |
if st.button("Predict and Mutate"):
|
534 |
if uploaded:
|
|
|
1 |
import streamlit as st
|
2 |
|
3 |
|
4 |
+
|
5 |
+
st.title("IRES-LM prediction and mutation")
|
6 |
+
|
7 |
+
# Input sequence
|
8 |
+
st.subheader("Input sequence")
|
9 |
+
|
10 |
+
seq = st.text_area("FASTA format only", value="")
|
11 |
+
st.subheader("Upload sequence file")
|
12 |
+
uploaded = st.file_uploader("Sequence file in FASTA format")
|
13 |
+
|
14 |
+
# augments
|
15 |
+
global output_filename, start_nt_position, end_nt_position, mut_by_prob, transform_type, mlm_tok_num, n_mut, n_designs_ep, n_sampling_designs_ep, n_mlm_recovery_sampling, mutate2stronger
|
16 |
+
output_filename = st.text_input("output a .csv file", value='IRES_LM_prediction_mutation')
|
17 |
+
start_nt_position = st.number_input("The start position of the mutation of this sequence, the first position is defined as 0", value=0)
|
18 |
+
end_nt_position = st.number_input("The last position of the mutation of this sequence, the last position is defined as length(sequence)-1 or -1", value=-1)
|
19 |
+
mut_by_prob = st.checkbox("Mutated by predicted Probability or Transformed Probability of the sequence", value=True)
|
20 |
+
transform_type = st.selectbox("Type of probability transformation",
|
21 |
+
['', 'sigmoid', 'logit', 'power_law', 'tanh'],
|
22 |
+
index=2)
|
23 |
+
mlm_tok_num = st.number_input("Number of masked tokens for each sequence per epoch", value=1)
|
24 |
+
n_mut = st.number_input("Maximum number of mutations for each sequence", value=3)
|
25 |
+
n_designs_ep = st.number_input("Number of mutations per epoch", value=10)
|
26 |
+
n_sampling_designs_ep = st.number_input("Number of sampling mutations from n_designs_ep per epoch", value=5)
|
27 |
+
n_mlm_recovery_sampling = st.number_input("Number of MLM recovery samplings (with AGCT recovery)", value=1)
|
28 |
+
mutate2stronger = st.checkbox("Mutate to stronger IRES variant, otherwise mutate to weaker IRES", value=True)
|
29 |
+
|
30 |
+
if not mut_by_prob and transform_type != '':
|
31 |
+
print("--transform_type must be '' when --mut_by_prob is False")
|
32 |
+
transform_type = ''
|
33 |
+
|
34 |
+
|
35 |
# Import necessary libraries
|
36 |
+
# import matplotlib
|
37 |
+
# import matplotlib.pyplot as plt
|
|
|
38 |
import numpy as np
|
39 |
import os
|
40 |
import pandas as pd
|
41 |
+
# import pathlib
|
42 |
import random
|
43 |
+
# import scanpy as sc
|
44 |
+
# import seaborn as sns
|
45 |
import torch
|
46 |
import torch.nn as nn
|
47 |
import torch.nn.functional as F
|
48 |
+
# from argparse import Namespace
|
49 |
from collections import Counter, OrderedDict
|
50 |
from copy import deepcopy
|
51 |
from esm import Alphabet, FastaBatchedDataset, ProteinBertModel, pretrained, MSATransformer
|
52 |
from esm.data import *
|
53 |
from esm.model.esm2 import ESM2
|
54 |
+
# from sklearn import preprocessing
|
55 |
+
# from sklearn.metrics import (confusion_matrix, roc_auc_score, auc,
|
56 |
+
# precision_recall_fscore_support,
|
57 |
+
# precision_recall_curve, classification_report,
|
58 |
+
# roc_auc_score, average_precision_score,
|
59 |
+
# precision_score, recall_score, f1_score,
|
60 |
+
# accuracy_score)
|
61 |
+
# from sklearn.model_selection import StratifiedKFold
|
62 |
+
# from sklearn.utils import class_weight
|
63 |
+
# from scipy.stats import spearmanr, pearsonr
|
64 |
from torch import nn
|
65 |
from torch.nn import Linear
|
66 |
from torch.nn.utils.rnn import pad_sequence
|
67 |
from torch.utils.data import Dataset, DataLoader
|
|
|
68 |
from tqdm import tqdm, trange
|
69 |
|
70 |
# Set global variables
|
71 |
+
# matplotlib.rcParams.update({'font.size': 7})
|
72 |
seed = 19961231
|
73 |
random.seed(seed)
|
74 |
np.random.seed(seed)
|
75 |
torch.manual_seed(seed)
|
76 |
+
# torch.cuda.manual_seed(seed)
|
77 |
+
# torch.backends.cudnn.deterministic = True
|
78 |
+
# torch.backends.cudnn.benchmark = False
|
79 |
|
80 |
|
81 |
global idx_to_tok, prefix, epochs, layers, heads, fc_node, dropout_prob, embed_dim, batch_toks, device, repr_layers, evaluation, include, truncate, return_contacts, return_representation, mask_toks_id, finetune
|
|
|
528 |
# print(pred)
|
529 |
return res_pd
|
530 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
531 |
# Run
|
532 |
if st.button("Predict and Mutate"):
|
533 |
if uploaded:
|