"""Streamlit front-end for IRES-LM: predict IRES activity and design mutants.

This part of the script imports dependencies, renders the input widgets,
stores the run options in module-level variables (read by the pipeline
functions defined further down the file) and seeds the random generators.
"""
import os
import random
from collections import Counter, OrderedDict
from copy import deepcopy
from io import StringIO

import numpy as np
import pandas as pd
import streamlit as st
import torch
import torch.nn as nn
import torch.nn.functional as F
from esm import Alphabet, FastaBatchedDataset, ProteinBertModel, pretrained, MSATransformer
# NOTE(review): wildcard import kept as-is; it may shadow names imported above.
from esm.data import *
from esm.model.esm2 import ESM2
from torch import nn
from torch.nn import Linear
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm, trange

st.title("IRES-LM prediction and mutation")

# Input sequence (used when no file is uploaded)
st.subheader("Input sequence")
seq = st.text_area("FASTA format only", value=">vir_CVB3_ires_00505.1\nTTAAAACAGCCTGTGGGTTGATCCCACCCACAGGCCCATTGGGCGCTAGCACTCTGGTATCACGGTACCTTTGTGCGCCTGTTTTATACCCCCTCCCCCAACTGTAACTTAGAAGTAACACACACCGATCAACAGTCAGCGTGGCACACCAGCCACGTTTTGATCAAGCACTTCTGTTACCCCGGACTGAGTATCAATAGACTGCTCACGCGGTTGAAGGAGAAAGCGTTCGTTATCCGGCCAACTACTTCGAAAAACCTAGTAACACCGTGGAAGTTGCAGAGTGTTTCGCTCAGCACTACCCCAGTGTAGATCAGGTCGATGAGTCACCGCATTCCCCACGGGCGACCGTGGCGGTGGCTGCGTTGGCGGCCTGCCCATGGGGAAACCCATGGGACGCTCTAATACAGACATGGTGCGAAGAGTCTATTGAGCTAGTTGGTAGTCCTCCGGCCCCTGAATGCGGCTAATCCTAACTGCGGAGCACACACCCTCAAGCCAGAGGGCAGTGTGTCGTAACGGGCAACTCTGCAGCGGAACCGACTACTTTGGGTGTCCGTGTTTCATTTTATTCCTATACTGGCTGCTTATGGTGACAATTGAGAGATCGTTACCATATAGCTATTGGATTGGCCATCCGGTGACTAATAGAGCTATTATATATCCCTTTGTTGGGTTTATACCACTTAGCTTGAAAGAGGTTAAAACATTACAATTCATTGTTAAGTTGAATACAGCAAA")

st.subheader("Upload sequence file")
uploaded = st.file_uploader("Sequence file in FASTA format")

# Run options. These module-level names are read as globals by the
# masking / recovery / mutation functions later in this script.
output_filename = st.text_input("output a .csv file", value='IRES_LM_prediction_mutation')
# Lower bounds reject values the pipeline cannot use (negative start, zero-size batches).
start_nt_position = st.number_input(
    "The start position of the mutation of this sequence, the first position is defined as 0",
    value=0, min_value=0)
end_nt_position = st.number_input(
    "The last position of the mutation of this sequence, the last position is defined as length(sequence)-1 or -1",
    value=-1, min_value=-1)
mut_by_prob = st.checkbox(
    "Mutated by predicted Probability or Transformed Probability of the sequence", value=True)
transform_type = st.selectbox(
    "Type of probability transformation", ['', 'sigmoid', 'logit', 'power_law', 'tanh'], index=2)
mlm_tok_num = st.number_input(
    "Number of masked tokens for each sequence per epoch", value=1, min_value=1)
n_mut = st.number_input(
    "Maximum number of mutations for each sequence", value=3, min_value=0)
n_designs_ep = st.number_input(
    "Number of mutations per epoch", value=10, min_value=1)
n_sampling_designs_ep = st.number_input(
    "Number of sampling mutations from n_designs_ep per epoch", value=5, min_value=1)
n_mlm_recovery_sampling = st.number_input(
    "Number of MLM recovery samplings (with AGCT recovery)", value=1, min_value=0)
mutate2stronger = st.checkbox(
    "Mutate to stronger IRES variant, otherwise mutate to weaker IRES", value=True)

if not mut_by_prob and transform_type != '':
    # Shown in the UI (print() output only reaches the server console).
    st.warning("--transform_type must be '' when --mut_by_prob is False")
    transform_type = ''

# Seed every RNG used below so a given set of inputs reproduces the same designs.
seed = 19961231
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
# ---------------------------------------------------------------------------
# Model hyper-parameters and token vocabulary
# ---------------------------------------------------------------------------
# Presumably these must match the architecture of the checkpoint in model.pt.
epochs = 5
layers = 6
heads = 16
embed_dim = 128
batch_toks = 4096
fc_node = 64
dropout_prob = 0.5
folds = 10
repr_layers = [-1]
include = ["mean"]
truncate = True
finetune = False
return_contacts = False
return_representation = False
device = "cpu"

alphabet = Alphabet(mask_prob=0.15, standard_toks='AGCT')

# Single-character stand-ins for the alphabet's token ids, so sequences can be
# edited as plain strings: '!' (id 7) is prepended as the first token of every
# sequence and '*' (id 8) marks a masked position.
tok_to_idx = {'-': 0, '&': 1, '?': 2, 'A': 3, 'G': 4, 'C': 5, 'T': 6, '!': 7, '*': 8, '|': 9}
idx_to_tok = {idx: tok for tok, idx in tok_to_idx.items()}
bos_toks_id = tok_to_idx['!']
mask_toks_id = tok_to_idx['*']

# The previous check compared against a literal whose special-token names had
# been reduced to empty strings (duplicate '' keys), so it could never pass.
# Check only the layout this script relies on: 10 tokens, A/G/C/T at ids 3-6.
assert len(alphabet.tok_to_idx) == len(tok_to_idx)
assert all(alphabet.tok_to_idx[nt] == tok_to_idx[nt] for nt in 'AGCT')

w1, w2, w3 = 1, 1, 100


class CNN_linear(nn.Module):
    """ESM2 encoder followed by a small 2-class MLP head.

    ``predict`` takes the first-position representation of ESM2 layer ``layers``,
    applies fc -> ReLU -> dropout -> output, and returns
    ``(class-1 probability, predicted class, ESM2 MLM logits, score)`` where
    ``score`` is ``prob_transform(class-1 probability)`` when ``transform_type``
    is set, otherwise the raw class-1 logit. (Class 1 presumably means "IRES".)
    """

    def __init__(self):
        super(CNN_linear, self).__init__()
        self.esm2 = ESM2(num_layers=layers, embed_dim=embed_dim,
                         attention_heads=heads, alphabet=alphabet)
        self.dropout = nn.Dropout(dropout_prob)
        self.relu = nn.ReLU()
        self.flatten = nn.Flatten()
        self.fc = nn.Linear(in_features=embed_dim, out_features=fc_node)
        self.output = nn.Linear(in_features=fc_node, out_features=2)

    def predict(self, tokens):
        x = self.esm2(tokens, [layers], need_head_weights=False,
                      return_contacts=False, return_representation=True)
        x_cls = x["representations"][layers][:, 0]
        o = self.output(self.dropout(self.relu(self.fc(x_cls))))
        y_prob = torch.softmax(o, dim=1)
        y_pred = torch.argmax(y_prob, dim=1)
        if transform_type:
            return y_prob[:, 1], y_pred, x['logits'], prob_transform(y_prob[:, 1])
        return y_prob[:, 1], y_pred, x['logits'], o[:, 1]

    def forward(self, x1, x2):
        # NOTE(review): predict() returns four values, so this two-value unpacking
        # raises ValueError if forward() is ever called. Nothing in this script
        # calls it; left unchanged until the intended pairwise output is confirmed.
        logit_1, repr_1 = self.predict(x1)
        logit_2, repr_2 = self.predict(x2)
        return (logit_1, logit_2), (repr_1, repr_2)


def prob_transform(prob, **kwargs):
    """Transform class-1 probabilities with the method named by ``transform_type``.

    :param prob: torch.Tensor of probabilities.
    :param kwargs: optional ``transform_type`` override (defaults to the UI
        global) and method parameters: ``x0``/``k`` (sigmoid), ``gamma``
        (power_law), ``k`` (tanh).
    :return: torch.Tensor of transformed values.
    :raises ValueError: for an unknown transformation name.
    """
    method = kwargs.pop('transform_type', transform_type)
    if method == 'sigmoid':
        x0 = kwargs.get('x0', 0.5)
        k = kwargs.get('k', 10.0)
        return 1 / (1 + torch.exp(-k * (prob - x0)))
    if method == 'logit':
        # Small epsilon avoids log(0) at prob == 0 or prob == 1.
        return torch.log(prob + 1e-6) - torch.log(1 - prob + 1e-6)
    if method == 'power_law':
        return torch.pow(prob, kwargs.get('gamma', 2.0))
    if method == 'tanh':
        return torch.tanh(kwargs.get('k', 2.0) * prob)
    raise ValueError(f"Unknown transform_type: {method!r}")


def random_replace(sequence, continuous_replace=False):
    """Return ``sequence`` with ``mlm_tok_num`` positions replaced by the mask symbol '*'.

    Masking is limited to the inclusive window ``[start_nt_position, end_nt_position]``
    taken from the UI globals; ``end_nt_position == -1`` means "to the end". An
    invalid window falls back to the whole sequence.

    :param sequence: nucleotide string.
    :param continuous_replace: mask one contiguous run of ``mlm_tok_num`` positions
        if True, otherwise ``mlm_tok_num`` distinct random positions.
    :return: masked string of the same length; the sequence is returned unmasked
        when the window is shorter than ``mlm_tok_num``.
    """
    n_mask = int(mlm_tok_num)
    seq_len = len(sequence)
    # Copy the globals into locals: assigning to the global names here would make
    # them local for the whole function and raise UnboundLocalError.
    start = int(start_nt_position)
    end = seq_len if end_nt_position == -1 else int(end_nt_position)
    if start < 0 or end > seq_len or start > end:
        print("Invalid start/end positions")
        start, end = 0, seq_len  # whole sequence

    pre_segment = sequence[:start]
    target_segment = list(sequence[start:end + 1])  # +1: end is inclusive
    post_segment = sequence[end + 1:]

    max_start_idx = len(target_segment) - n_mask
    if max_start_idx < 0:
        return sequence  # window too short to hold n_mask masks
    if continuous_replace:
        first = random.randint(0, max_start_idx)
        indices = range(first, first + n_mask)
    else:
        indices = random.sample(range(len(target_segment)), n_mask)
    for idx in indices:
        target_segment[idx] = '*'
    return ''.join([pre_segment] + target_segment + [post_segment])


def _encode(sequence):
    """Token ids for ``sequence`` with the leading '!' token (id 7) prepended."""
    return [bos_toks_id] + [tok_to_idx[token] for token in sequence]


def mlm_seq(seq, continuous_replace=False):
    """Mask one sequence; return (seq, masked_seq, seq_tokens, masked_tokens)."""
    masked_seq = random_replace(seq, continuous_replace)
    return seq, masked_seq, torch.LongTensor(_encode(seq)), torch.LongTensor(_encode(masked_seq))


def batch_mlm_seq(seq_list, continuous_replace=False):
    """Mask every sequence in ``seq_list`` independently.

    Returns (sequences, masked sequences, token tensor, masked token tensor); the
    tensors are built from nested lists, so all sequences must have equal length.
    """
    batch_seq, batch_masked_seq, seq_tokens, masked_tokens = [], [], [], []
    for seq in seq_list:
        masked_seq = random_replace(seq, continuous_replace)
        batch_seq.append(seq)
        batch_masked_seq.append(masked_seq)
        seq_tokens.append(_encode(seq))
        masked_tokens.append(_encode(masked_seq))
    return batch_seq, batch_masked_seq, torch.LongTensor(seq_tokens), torch.LongTensor(masked_tokens)


def _sample_alternative_nucleotides(agct_logits, best_idx, exclude_low_prob=False):
    """Sample distinct non-argmax nucleotide indices (0-3 within A/G/C/T) for one masked position.

    Draws up to ``n_mlm_recovery_sampling`` indices without replacement, weighted by
    the softmax of ``agct_logits`` with ``best_idx`` excluded. The draw size is
    capped at the number of eligible nucleotides, so a too-large request no
    longer fails (retrying a deterministic failure could never succeed).
    """
    probs = torch.softmax(agct_logits.detach().float().cpu(), dim=-1).numpy().astype(np.float64)
    probs[best_idx] = 0.0
    if exclude_low_prob:
        probs[probs < 1.0 / probs.size] = 0.0
    n_draw = min(int(n_mlm_recovery_sampling), int(np.count_nonzero(probs)))
    if n_draw <= 0:
        return []
    probs /= probs.sum()  # float64 renormalisation so choice() accepts p
    return [int(k) for k in np.random.choice(probs.size, size=n_draw, p=probs, replace=False)]


def _stack_tokens(token_rows, like):
    """Stack token rows into a LongTensor; empty (0, L) tensor when there are none."""
    if token_rows:
        return torch.stack(token_rows).long()
    return torch.empty((0, like.shape[-1]), dtype=torch.long)


def recovered_mlm_tokens(masked_seqs, masked_toks, esm_logits, exclude_low_prob=False):
    """Fill each masked position independently (other masks stay '*').

    For every masked token (id 8) the candidates are the ESM argmax over A/G/C/T
    plus sampled alternatives. Token index j maps to string index j - 1 because
    of the leading '!' token.
    """
    agct_logits = esm_logits[:, :, 3:7]  # keep only the A/G/C/T logits
    predicted_toks = (agct_logits.argmax(dim=-1) + 3).tolist()
    batch_size, seq_len, _ = agct_logits.size()
    recovered_sequences, recovered_toks = [], []
    for i in range(batch_size):
        row_toks = masked_toks[i].tolist()
        for j in range(seq_len):
            if row_toks[j] != mask_toks_id:
                continue
            sampled = _sample_alternative_nucleotides(agct_logits[i, j], predicted_toks[i][j] - 3, exclude_low_prob)
            for tok_id in [predicted_toks[i][j]] + [3 + k for k in sampled]:
                chars = list(masked_seqs[i])
                chars[j - 1] = idx_to_tok[tok_id]
                toks = masked_toks[i].clone()
                toks[j] = tok_id
                recovered_sequences.append(''.join(chars))
                recovered_toks.append(toks)
    return recovered_sequences, _stack_tokens(recovered_toks, masked_toks)


def recovered_mlm_multi_tokens(masked_seqs, masked_toks, esm_logits, exclude_low_prob=False):
    """Fill all masked positions of each row, expanding every combination of candidates.

    Each masked position contributes the argmax nucleotide plus sampled alternatives;
    the variants of a row are the cartesian product over its masked positions.
    Duplicate sequences across the batch are dropped (first kept). Rows without any
    mask contribute nothing.
    """
    agct_logits = esm_logits[:, :, 3:7]  # keep only the A/G/C/T logits
    predicted_toks = (agct_logits.argmax(dim=-1) + 3).tolist()
    batch_size, seq_len, _ = agct_logits.size()
    recovered_sequences, recovered_toks = [], []
    for i in range(batch_size):
        row_toks = masked_toks[i].tolist()
        variants = None  # list of (chars, tokens); None until the first mask is seen
        for j in range(seq_len):
            if row_toks[j] != mask_toks_id:
                continue
            sampled = _sample_alternative_nucleotides(agct_logits[i, j], predicted_toks[i][j] - 3, exclude_low_prob)
            choices = [predicted_toks[i][j]] + [3 + k for k in sampled]
            base = variants if variants is not None else [(list(masked_seqs[i]), masked_toks[i])]
            variants = []
            for tok_id in choices:
                for chars, toks in base:
                    new_chars = list(chars)
                    new_chars[j - 1] = idx_to_tok[tok_id]
                    # Clone so variants never share (and overwrite) one token tensor.
                    new_toks = toks.clone()
                    new_toks[j] = tok_id
                    variants.append((new_chars, new_toks))
        for chars, toks in variants or []:
            recovered_sequences.append(''.join(chars))
            recovered_toks.append(toks)
    recovered_sequences, recovered_toks = remove_duplicates_double(recovered_sequences, recovered_toks)
    return recovered_sequences, _stack_tokens(recovered_toks, masked_toks)


def mismatched_positions(s1, s2):
    """Return the number of positions where two equal-length strings differ."""
    return sum(c1 != c2 for c1, c2 in zip(s1, s2))


def _dedupe_parallel(keys, *columns):
    """Drop repeated ``keys`` (keeping the first) from parallel lists; return a tuple of lists."""
    seen = set()
    out = [[] for _ in range(1 + len(columns))]
    for row in zip(keys, *columns):
        if row[0] in seen:
            continue
        seen.add(row[0])
        for bucket, value in zip(out, row):
            bucket.append(value)
    return tuple(out)


def remove_duplicates_triple(filtered_mut_seqs, filtered_mut_probs, filtered_mut_logits):
    """De-duplicate sequences, keeping the first prob/logit seen for each."""
    return _dedupe_parallel(filtered_mut_seqs, filtered_mut_probs, filtered_mut_logits)


def remove_duplicates_double(filtered_mut_seqs, filtered_mut_probs):
    """De-duplicate sequences, keeping the first paired value seen for each."""
    return _dedupe_parallel(filtered_mut_seqs, filtered_mut_probs)


def _load_model(model_path='model.pt'):
    """Build CNN_linear and load weights from ``model_path`` in eval mode."""
    state_dict = torch.load(model_path, map_location=torch.device(device))
    # Strip a 'module.' prefix (presumably from a DataParallel-wrapped save).
    new_state_dict = OrderedDict((k.replace('module.', ''), v) for k, v in state_dict.items())
    model = CNN_linear().to(device)
    model.load_state_dict(new_state_dict, strict=False)
    model.eval()
    return model


def _propose_mutants(model, seed_seq, wt_prob, wt_logit):
    """Mask/recover ``seed_seq`` ``n_designs_ep`` times; keep mutants that beat seed and WT.

    Scores are probabilities when ``mut_by_prob`` else the model's score output.
    With ``mutate2stronger`` a mutant must score >= both seed and wild type,
    otherwise strictly below both. Returns (sequences, probs, scores) as lists.
    """
    _, masked_seqs, seed_toks, masked_toks = batch_mlm_seq(
        [seed_seq] * int(n_designs_ep), continuous_replace=True)
    with torch.no_grad():
        seed_prob, _, _, seed_logit = model.predict(seed_toks[:1].to(device))
        _, _, masked_esm_logits, _ = model.predict(masked_toks.to(device))
        mut_seqs, mut_toks = recovered_mlm_multi_tokens(masked_seqs, masked_toks, masked_esm_logits)
        if not mut_seqs:
            return [], [], []
        mut_probs, _, _, mut_logits = model.predict(mut_toks.to(device))

    probs = mut_probs.detach().cpu().tolist()
    logits = mut_logits.detach().cpu().tolist()
    if mut_by_prob:
        scores, seed_score, wt_score = probs, seed_prob.item(), wt_prob
    else:
        scores, seed_score, wt_score = logits, seed_logit.item(), wt_logit

    kept_seqs, kept_probs, kept_logits = [], [], []
    for mut, prob, logit, score in zip(mut_seqs, probs, logits, scores):
        if mutate2stronger:
            accept = score >= seed_score and score >= wt_score
        else:
            accept = score < seed_score and score < wt_score
        if accept:
            kept_seqs.append(mut)
            kept_probs.append(prob)
            kept_logits.append(logit)
    return kept_seqs, kept_probs, kept_logits


def _sample_seeds(seqs, probs, logits, k):
    """Keep at most ``k`` candidates, sampled without replacement weighted by prob."""
    if len(seqs) <= k:
        return seqs, probs, logits
    p = np.asarray(probs, dtype=np.float64)
    total = p.sum()
    # Fall back to uniform sampling if the weights cannot support k distinct draws.
    p = p / total if total > 0 and np.count_nonzero(p) >= k else None
    chosen = np.random.choice(len(seqs), k, p=p, replace=False)
    return [seqs[c] for c in chosen], [probs[c] for c in chosen], [logits[c] for c in chosen]


def mutated_seq(wt_seq, wt_label, model=None):
    """Iteratively design mutants of ``wt_seq`` over ``n_mut`` epochs.

    Each epoch masks/recovers every current seed, keeps improving (or weakening)
    mutants, and samples up to ``n_sampling_designs_ep`` of them as next seeds.

    :param wt_seq: wild-type nucleotide string (A/G/C/T).
    :param wt_label: identifier, only used in the console log.
    :param model: a loaded CNN_linear; loaded from model.pt when omitted.
    :return: DataFrame with columns mutated_seq, predicted_logit,
        predicted_probability, mutated_num (wild type included with 0 mutations),
        sorted by probability descending and de-duplicated by sequence.
    """
    if model is None:
        model = _load_model()
    wt_tok = torch.LongTensor([_encode(wt_seq)]).to(device)
    with torch.no_grad():
        wt_prob, wt_pred, _, wt_logit = model.predict(wt_tok)
    print('Wild Type: Length = ', len(wt_seq), '\n', wt_seq)
    print(f'Wild Type: Label = {wt_label}, Y_pred = {wt_pred.item()}, Y_prob = {wt_prob.item():.2%}')
    wt_prob_v, wt_logit_v = wt_prob.item(), wt_logit.item()

    n_epochs = int(n_mut)
    mutated_seqs = []
    seeds_ep = [wt_seq]
    pbar = st.progress(0, text="mutated number of sequence")
    for epoch in range(1, n_epochs + 1):
        next_seqs, next_probs, next_logits = [], [], []
        for seed_seq in seeds_ep:
            seqs, probs, logits = _propose_mutants(model, seed_seq, wt_prob_v, wt_logit_v)
            next_seqs.extend(seqs)
            next_probs.extend(probs)
            next_logits.extend(logits)
        next_seqs, next_probs, next_logits = remove_duplicates_triple(next_seqs, next_probs, next_logits)
        next_seqs, next_probs, next_logits = _sample_seeds(
            next_seqs, next_probs, next_logits, int(n_sampling_designs_ep))
        for mut, prob, logit in zip(next_seqs, next_probs, next_logits):
            mutated_seqs.append((mut, logit, prob, mismatched_positions(wt_seq, mut)))
        seeds_ep = next_seqs
        # Update after the epoch so the fraction stays within [0, 1].
        pbar.progress(epoch / n_epochs, text="Mutating")
    st.success('Done', icon="✅")

    mutated_seqs.append((wt_seq, wt_logit_v, wt_prob_v, 0))
    mutated_seqs.sort(key=lambda row: row[2], reverse=True)
    return pd.DataFrame(
        mutated_seqs,
        columns=['mutated_seq', 'predicted_logit', 'predicted_probability', 'mutated_num'],
    ).drop_duplicates('mutated_seq')


def _iter_fasta(text):
    """Yield (header, sequence) pairs from FASTA text; lines before the first '>' are ignored."""
    header, chunks = None, []
    for line in text.splitlines():
        line = line.strip()
        if line.startswith('>'):
            if header is not None:
                yield header, ''.join(chunks)
            header, chunks = line[1:].strip(), []
        elif header is not None and line:
            chunks.append(''.join(line.split()))
    if header is not None:
        yield header, ''.join(chunks)


def read_raw(raw_input, max_len=None):
    """Parse FASTA text into (ids, sequences), keeping only A/G/C/T(U) records.

    RNA input is back-transcribed (U -> T) and upper-cased. When ``max_len`` is
    given only the last ``max_len`` nucleotides of each record are kept.
    Skipped records are reported in the UI.
    """
    ids, sequences = [], []
    for header, raw_seq in _iter_fasta(raw_input):
        sequence = raw_seq.upper().replace('U', 'T')
        if max_len:
            sequence = sequence[-max_len:]
        if not sequence:
            st.write(f"Record '{header}' was skipped because it contains no sequence.")
            continue
        if not set(sequence).issubset(set("AGCT")):
            st.write(f"Record '{header}' was skipped for containing invalid characters. Only A, G, C, T(U) are allowed.")
            continue
        ids.append(header.split()[0] if header else '')
        sequences.append(sequence)
    return ids, sequences


def predict_raw(raw_input):
    """Run the mutation design for every valid FASTA record in ``raw_input``.

    Failures are reported per record and skipped. Returns one DataFrame (with a
    leading ``seq_id`` column) concatenated over records, or an empty DataFrame.
    """
    model = _load_model()
    ids, seqs = read_raw(raw_input)
    results = []
    for wt_seq, wt_id in zip(seqs, ids):
        try:
            res = mutated_seq(wt_seq, wt_id, model=model)
        except Exception as err:  # best-effort per record: report and continue
            st.write('====Please Try Again this sequence: ', wt_id, wt_seq)
            st.exception(err)
            continue
        res.insert(0, 'seq_id', wt_id)
        results.append(res)
    if not results:
        return pd.DataFrame()
    return pd.concat(results, ignore_index=True)


# ---------------------------------------------------------------------------
# Run
# ---------------------------------------------------------------------------
if st.button("Predict and Mutate"):
    # 'utf-8-sig' also drops a byte-order mark that would hide the first '>' header.
    raw_input = uploaded.getvalue().decode('utf-8-sig') if uploaded else seq
    result = predict_raw(raw_input)
    if result.empty:
        st.warning("No sequence could be processed.")
    else:
        result_file = result.to_csv(index=False)
        st.download_button("Download", result_file, file_name=output_filename + ".csv")
        st.dataframe(result)