# Version-control header (appears to be: author, commit message, short commit hash):
# root
# data cleaning, BLAST, and splitting code with source data; also deleting unnecessary files
# 6efd653
import matplotlib.pyplot as plt | |
import numpy as np | |
from scipy.stats import entropy | |
from sklearn.manifold import TSNE | |
import pickle | |
import pandas as pd | |
import os | |
from fuson_plm.utils.logging import log_update | |
from fuson_plm.utils.visualizing import set_font, visualize_splits | |
def _find_benchmark_cluster_reps(clusters, fuson_db):
    """Return the representative seq_ids of clusters that contain a benchmark sequence.

    Args:
        clusters: DataFrame with 'member seq_id' and 'representative seq_id' columns.
        fuson_db: DataFrame with 'seq_id' and 'benchmark' columns; a row is treated
            as a benchmark sequence when its 'benchmark' value is non-null.

    Returns:
        List of unique representative seq_ids, in order of first appearance in
        ``clusters``.
    """
    # One .loc with (row mask, column label) instead of chained indexing.
    benchmark_seq_ids = fuson_db.loc[fuson_db["benchmark"].notna(), "seq_id"]
    has_benchmark_member = clusters["member seq_id"].isin(benchmark_seq_ids)
    return clusters.loc[has_benchmark_member, "representative seq_id"].unique().tolist()


def _add_seq_ids_to_source_data(split_vis_dir, seq_to_id):
    """Swap the 'sequence' column for a 'seq_id' column in every CSV of a directory.

    Each ``.csv`` file in ``split_vis_dir`` that has a 'sequence' column is
    rewritten in place: a 'seq_id' column is appended by looking each sequence
    up in ``seq_to_id``, and the 'sequence' column is dropped. CSV files without
    a 'sequence' column are left untouched.

    Args:
        split_vis_dir: directory whose ``.csv`` files are edited.
        seq_to_id: mapping from amino-acid sequence to seq_id.
    """
    # os.listdir order is arbitrary; sort so processing and the log are deterministic.
    csv_names = sorted(name for name in os.listdir(split_vis_dir) if name.endswith(".csv"))
    log_update(f"Adding seq_ids to the following files: {csv_names}")
    for name in csv_names:
        path = os.path.join(split_vis_dir, name)
        source_data = pd.read_csv(path)
        if "sequence" not in source_data.columns:
            continue
        source_data["seq_id"] = source_data["sequence"].map(seq_to_id)
        # Series.map yields NaN for sequences missing from the mapping. The
        # 'sequence' column is dropped below, so report this instead of
        # silently writing blank seq_ids.
        n_unmapped = int(source_data["seq_id"].isna().sum())
        if n_unmapped:
            log_update(
                f"WARNING: {n_unmapped} sequence(s) in {name} could not be mapped "
                f"to a seq_id; their seq_id will be empty."
            )
        source_data.drop(columns=["sequence"]).to_csv(path, index=False)


def main():
    """Visualize the train/val/test cluster splits and attach seq_ids to the saved source data.

    Reads the three cluster-split CSVs from ``splits/`` and ``fuson_db.csv``
    from the working directory. Calls ``visualize_splits``, passing it the
    representatives of clusters that contain benchmark sequences. Then rewrites
    every ``.csv`` in ``splits/split_vis`` that has a 'sequence' column so that
    it carries 'seq_id' instead.
    """
    set_font()
    train_clusters = pd.read_csv("splits/train_cluster_split.csv")
    val_clusters = pd.read_csv("splits/val_cluster_split.csv")
    test_clusters = pd.read_csv("splits/test_cluster_split.csv")
    clusters = pd.concat([train_clusters, val_clusters, test_clusters])
    fuson_db = pd.read_csv("fuson_db.csv")

    benchmark_cluster_reps = _find_benchmark_cluster_reps(clusters, fuson_db)
    visualize_splits(train_clusters, val_clusters, test_clusters, benchmark_cluster_reps)

    # Add seq_id to every source-data file saved by visualize_splits.
    seq_to_id_dict = dict(zip(fuson_db["aa_seq"], fuson_db["seq_id"]))
    _add_seq_ids_to_source_data("splits/split_vis", seq_to_id_dict)
if "__main__" == __name__:
    # Run the pipeline only when executed as a script, not when imported.
    main()