a96123155 commited on
Commit
82d4030
·
1 Parent(s): ad079a8
Files changed (2) hide show
  1. .DS_Store +0 -0
  2. app.py +51 -52
.DS_Store CHANGED
Binary files a/.DS_Store and b/.DS_Store differ
 
app.py CHANGED
@@ -1,52 +1,81 @@
1
  import streamlit as st
2
 
3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4
  # Import necessary libraries
5
- import argparse
6
- import matplotlib
7
- import matplotlib.pyplot as plt
8
  import numpy as np
9
  import os
10
  import pandas as pd
11
- import pathlib
12
  import random
13
- import scanpy as sc
14
- import seaborn as sns
15
  import torch
16
  import torch.nn as nn
17
  import torch.nn.functional as F
18
- from argparse import Namespace
19
  from collections import Counter, OrderedDict
20
  from copy import deepcopy
21
  from esm import Alphabet, FastaBatchedDataset, ProteinBertModel, pretrained, MSATransformer
22
  from esm.data import *
23
  from esm.model.esm2 import ESM2
24
- from sklearn import preprocessing
25
- from sklearn.metrics import (confusion_matrix, roc_auc_score, auc,
26
- precision_recall_fscore_support,
27
- precision_recall_curve, classification_report,
28
- roc_auc_score, average_precision_score,
29
- precision_score, recall_score, f1_score,
30
- accuracy_score)
31
- from sklearn.model_selection import StratifiedKFold
32
- from sklearn.utils import class_weight
33
- from scipy.stats import spearmanr, pearsonr
34
  from torch import nn
35
  from torch.nn import Linear
36
  from torch.nn.utils.rnn import pad_sequence
37
  from torch.utils.data import Dataset, DataLoader
38
- from torch.optim import lr_scheduler
39
  from tqdm import tqdm, trange
40
 
41
  # Set global variables
42
- matplotlib.rcParams.update({'font.size': 7})
43
  seed = 19961231
44
  random.seed(seed)
45
  np.random.seed(seed)
46
  torch.manual_seed(seed)
47
- torch.cuda.manual_seed(seed)
48
- torch.backends.cudnn.deterministic = True
49
- torch.backends.cudnn.benchmark = False
50
 
51
 
52
  global idx_to_tok, prefix, epochs, layers, heads, fc_node, dropout_prob, embed_dim, batch_toks, device, repr_layers, evaluation, include, truncate, return_contacts, return_representation, mask_toks_id, finetune
@@ -499,36 +528,6 @@ def predict_raw(raw_input):
499
  # print(pred)
500
  return res_pd
501
 
502
-
503
- st.title("IRES-LM prediction and mutation")
504
-
505
- # Input sequence
506
- st.subheader("Input sequence")
507
-
508
- seq = st.text_area("FASTA format only", value="")
509
- st.subheader("Upload sequence file")
510
- uploaded = st.file_uploader("Sequence file in FASTA format")
511
-
512
- # augments
513
- global output_filename, start_nt_position, end_nt_position, mut_by_prob, transform_type, mlm_tok_num, n_mut, n_designs_ep, n_sampling_designs_ep, n_mlm_recovery_sampling, mutate2stronger
514
- output_filename = st.text_input("output a .csv file", value='IRES_LM_prediction_mutation')
515
- start_nt_position = st.number_input("The start position of the mutation of this sequence, the first position is defined as 0", value=0)
516
- end_nt_position = st.number_input("The last position of the mutation of this sequence, the last position is defined as length(sequence)-1 or -1", value=-1)
517
- mut_by_prob = st.checkbox("Mutated by predicted Probability or Transformed Probability of the sequence", value=True)
518
- transform_type = st.selectbox("Type of probability transformation",
519
- ['', 'sigmoid', 'logit', 'power_law', 'tanh'],
520
- index=2)
521
- mlm_tok_num = st.number_input("Number of masked tokens for each sequence per epoch", value=1)
522
- n_mut = st.number_input("Maximum number of mutations for each sequence", value=3)
523
- n_designs_ep = st.number_input("Number of mutations per epoch", value=10)
524
- n_sampling_designs_ep = st.number_input("Number of sampling mutations from n_designs_ep per epoch", value=5)
525
- n_mlm_recovery_sampling = st.number_input("Number of MLM recovery samplings (with AGCT recovery)", value=1)
526
- mutate2stronger = st.checkbox("Mutate to stronger IRES variant, otherwise mutate to weaker IRES", value=True)
527
-
528
- if not mut_by_prob and transform_type != '':
529
- print("--transform_type must be '' when --mut_by_prob is False")
530
- transform_type = ''
531
-
532
  # Run
533
  if st.button("Predict and Mutate"):
534
  if uploaded:
 
1
  import streamlit as st
2
 
3
 
4
+
5
+ st.title("IRES-LM prediction and mutation")
6
+
7
+ # Input sequence
8
+ st.subheader("Input sequence")
9
+
10
+ seq = st.text_area("FASTA format only", value="")
11
+ st.subheader("Upload sequence file")
12
+ uploaded = st.file_uploader("Sequence file in FASTA format")
13
+
14
+ # augments
15
+ global output_filename, start_nt_position, end_nt_position, mut_by_prob, transform_type, mlm_tok_num, n_mut, n_designs_ep, n_sampling_designs_ep, n_mlm_recovery_sampling, mutate2stronger
16
+ output_filename = st.text_input("output a .csv file", value='IRES_LM_prediction_mutation')
17
+ start_nt_position = st.number_input("The start position of the mutation of this sequence, the first position is defined as 0", value=0)
18
+ end_nt_position = st.number_input("The last position of the mutation of this sequence, the last position is defined as length(sequence)-1 or -1", value=-1)
19
+ mut_by_prob = st.checkbox("Mutated by predicted Probability or Transformed Probability of the sequence", value=True)
20
+ transform_type = st.selectbox("Type of probability transformation",
21
+ ['', 'sigmoid', 'logit', 'power_law', 'tanh'],
22
+ index=2)
23
+ mlm_tok_num = st.number_input("Number of masked tokens for each sequence per epoch", value=1)
24
+ n_mut = st.number_input("Maximum number of mutations for each sequence", value=3)
25
+ n_designs_ep = st.number_input("Number of mutations per epoch", value=10)
26
+ n_sampling_designs_ep = st.number_input("Number of sampling mutations from n_designs_ep per epoch", value=5)
27
+ n_mlm_recovery_sampling = st.number_input("Number of MLM recovery samplings (with AGCT recovery)", value=1)
28
+ mutate2stronger = st.checkbox("Mutate to stronger IRES variant, otherwise mutate to weaker IRES", value=True)
29
+
30
+ if not mut_by_prob and transform_type != '':
31
+ print("--transform_type must be '' when --mut_by_prob is False")
32
+ transform_type = ''
33
+
34
+
35
  # Import necessary libraries
36
+ # import matplotlib
37
+ # import matplotlib.pyplot as plt
 
38
  import numpy as np
39
  import os
40
  import pandas as pd
41
+ # import pathlib
42
  import random
43
+ # import scanpy as sc
44
+ # import seaborn as sns
45
  import torch
46
  import torch.nn as nn
47
  import torch.nn.functional as F
48
+ # from argparse import Namespace
49
  from collections import Counter, OrderedDict
50
  from copy import deepcopy
51
  from esm import Alphabet, FastaBatchedDataset, ProteinBertModel, pretrained, MSATransformer
52
  from esm.data import *
53
  from esm.model.esm2 import ESM2
54
+ # from sklearn import preprocessing
55
+ # from sklearn.metrics import (confusion_matrix, roc_auc_score, auc,
56
+ # precision_recall_fscore_support,
57
+ # precision_recall_curve, classification_report,
58
+ # roc_auc_score, average_precision_score,
59
+ # precision_score, recall_score, f1_score,
60
+ # accuracy_score)
61
+ # from sklearn.model_selection import StratifiedKFold
62
+ # from sklearn.utils import class_weight
63
+ # from scipy.stats import spearmanr, pearsonr
64
  from torch import nn
65
  from torch.nn import Linear
66
  from torch.nn.utils.rnn import pad_sequence
67
  from torch.utils.data import Dataset, DataLoader
 
68
  from tqdm import tqdm, trange
69
 
70
  # Set global variables
71
+ # matplotlib.rcParams.update({'font.size': 7})
72
  seed = 19961231
73
  random.seed(seed)
74
  np.random.seed(seed)
75
  torch.manual_seed(seed)
76
+ # torch.cuda.manual_seed(seed)
77
+ # torch.backends.cudnn.deterministic = True
78
+ # torch.backends.cudnn.benchmark = False
79
 
80
 
81
  global idx_to_tok, prefix, epochs, layers, heads, fc_node, dropout_prob, embed_dim, batch_toks, device, repr_layers, evaluation, include, truncate, return_contacts, return_representation, mask_toks_id, finetune
 
528
  # print(pred)
529
  return res_pd
530
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
531
  # Run
532
  if st.button("Predict and Mutate"):
533
  if uploaded: