import os
import pickle

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from scipy.stats import entropy
from sklearn.manifold import TSNE

from fuson_plm.utils.logging import log_update
from fuson_plm.utils.visualizing import set_font, visualize_splits

# NOTE(review): plt, np, entropy, TSNE and pickle are not referenced in this
# script; they are kept in case other tooling relies on them being imported here.


def _get_benchmark_cluster_reps(clusters, fuson_db):
    """Return the representative seq_ids of clusters that contain a benchmark sequence.

    Args:
        clusters: Cluster assignments with 'member seq_id' and
            'representative seq_id' columns.
        fuson_db: Sequence table with 'seq_id' and 'benchmark' columns; a row
            counts as a benchmark sequence when 'benchmark' is non-null.

    Returns:
        A list of unique representative seq_ids, in order of first appearance.
    """
    # Single .loc call instead of chained indexing (df.loc[mask]['col']).
    benchmark_seq_ids = fuson_db.loc[fuson_db["benchmark"].notna(), "seq_id"]
    has_benchmark_member = clusters["member seq_id"].isin(benchmark_seq_ids)
    return clusters.loc[has_benchmark_member, "representative seq_id"].unique().tolist()


def _add_seq_ids_to_source_data(vis_dir, seq_to_id):
    """Replace the 'sequence' column with a 'seq_id' column in every CSV in vis_dir.

    Files without a 'sequence' column are left untouched. Files that do have
    one are overwritten in place (without the 'sequence' column).

    Args:
        vis_dir: Directory containing the source-data CSV files.
        seq_to_id: Mapping from amino-acid sequence to seq_id.
    """
    # Sorted so processing and the log message are deterministic across runs.
    csv_files = sorted(fname for fname in os.listdir(vis_dir) if fname.endswith(".csv"))
    log_update(f"Adding seq_ids to the following files: {csv_files}")

    for fname in csv_files:
        path = os.path.join(vis_dir, fname)
        source_data = pd.read_csv(path)
        if "sequence" not in source_data.columns:
            continue

        seq_ids = source_data["sequence"].map(seq_to_id)
        # Series.map yields NaN for sequences missing from seq_to_id. Because the
        # 'sequence' column is dropped and the file overwritten, those rows would
        # lose their identity silently -- surface it instead.
        n_unmapped = int((source_data["sequence"].notna() & seq_ids.isna()).sum())
        if n_unmapped:
            log_update(
                f"WARNING: {n_unmapped} sequence(s) in {fname} have no seq_id in fuson_db; "
                f"their seq_id will be empty"
            )

        source_data["seq_id"] = seq_ids
        source_data.drop(columns=["sequence"]).to_csv(path, index=False)


def main():
    """Visualize the train/val/test cluster splits and tag source-data files with seq_ids.

    Reads the cluster split CSVs from 'splits/' and 'fuson_db.csv' from the
    working directory. Marks the clusters that contain benchmark sequences and
    passes everything to visualize_splits. Then rewrites every CSV in
    'splits/split_vis' so that sequences are replaced by their seq_ids.
    """
    set_font()

    train_clusters = pd.read_csv("splits/train_cluster_split.csv")
    val_clusters = pd.read_csv("splits/val_cluster_split.csv")
    test_clusters = pd.read_csv("splits/test_cluster_split.csv")
    clusters = pd.concat([train_clusters, val_clusters, test_clusters])

    fuson_db = pd.read_csv("fuson_db.csv")

    benchmark_cluster_reps = _get_benchmark_cluster_reps(clusters, fuson_db)
    visualize_splits(train_clusters, val_clusters, test_clusters, benchmark_cluster_reps)

    # visualize_splits saves source-data CSVs; the directory below is where this
    # script expects to find them.
    # NOTE(review): if fuson_db has duplicate aa_seq values, dict() keeps the
    # last seq_id seen for each sequence.
    seq_to_id = dict(zip(fuson_db["aa_seq"], fuson_db["seq_id"]))
    _add_seq_ids_to_source_data("splits/split_vis", seq_to_id)


if __name__ == "__main__":
    main()