|  | """ | 
					
						
						|  | This file defines the layout of the app including the header, sidebar, and tabs in the | 
					
						
						|  | main content area. | 
					
						
						|  | """ | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | import streamlit as st | 
					
						
						|  | import streamlit.components.v1 as components | 
					
						
						|  | from PIL import Image | 
					
						
						|  | import pandas as pd | 
					
						
						|  | import yaml | 
					
						
						|  |  | 
					
						
						|  | from src.data_preprocessing.create_descriptors import handle_inputs | 
					
						
						|  | from src.app.constants import (summary_text, | 
					
						
						|  | mhnfs_text, | 
					
						
						|  | citation_text, | 
					
						
						|  | few_shot_learning_text, | 
					
						
						|  | under_the_hood_text, | 
					
						
						|  | usage_text, | 
					
						
						|  | data_text, | 
					
						
						|  | trust_text, | 
					
						
						|  | example_trustworthy_text, | 
					
						
						|  | example_nottrustworthy_text) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
# Upper bound on the number of molecules accepted per input (query, active
# support set, inactive support set). LayoutMaker rejects a prediction request
# with an error message when any input exceeds this length.
MAX_INPUT_LENGTH = 20
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | class LayoutMaker(): | 
					
						
						|  | """ | 
					
						
						|  | This class includes all the design choices regarding the layout of the app. This | 
					
						
						|  | class can be used in the main file to define header, sidebar, and main content area. | 
					
						
						|  | """ | 
					
						
						|  |  | 
					
						
						|  | def __init__(self): | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | self.inputs = dict() | 
					
						
						|  | self.inputs_lists = dict() | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | self.predictions = None | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | self.buttons = dict() | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | self.summary_text = summary_text | 
					
						
						|  | self.mhnfs_text = mhnfs_text | 
					
						
						|  | self.citation_text = citation_text | 
					
						
						|  | self.few_shot_learning_text = few_shot_learning_text | 
					
						
						|  | self.under_the_hood_text = under_the_hood_text | 
					
						
						|  | self.usage_text = usage_text | 
					
						
						|  | self.data_text = data_text | 
					
						
						|  | self.trust_text = trust_text | 
					
						
						|  | self.example_trustworthy_text = example_trustworthy_text | 
					
						
						|  | self.example_nottrustworthy_text = example_nottrustworthy_text | 
					
						
						|  |  | 
					
						
						|  | self.df_trustworthy = pd.read_csv("./assets/example_csv/predictions/" | 
					
						
						|  | "trustworthy_example.csv") | 
					
						
						|  | self.df_nottrustworthy = pd.read_csv("./assets/example_csv/predictions/" | 
					
						
						|  | "nottrustworthy_example.csv") | 
					
						
						|  |  | 
					
						
						|  | self.max_input_length = MAX_INPUT_LENGTH | 
					
						
						|  |  | 
					
						
						|  | def make_sidebar(self): | 
					
						
						|  | """ | 
					
						
						|  | This function defines the sidebar of the app. It includes the logo, query box, | 
					
						
						|  | support set boxes, and predict buttons. | 
					
						
						|  | It returns the stored inputs (for query and support set) and the buttons which | 
					
						
						|  | allow for user interactions. | 
					
						
						|  | """ | 
					
						
						|  | with st.sidebar: | 
					
						
						|  |  | 
					
						
						|  | logo = Image.open("./assets/logo.png") | 
					
						
						|  | st.image(logo) | 
					
						
						|  | st.divider() | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | self._make_query_box() | 
					
						
						|  | st.divider() | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | self._make_active_support_set_box() | 
					
						
						|  | st.divider() | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | self._make_inactive_support_set_box() | 
					
						
						|  | st.divider() | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | self.buttons["predict"] = st.button("Predict...") | 
					
						
						|  | self.buttons["reset"] = st.button("Reset") | 
					
						
						|  |  | 
					
						
						|  | return self.inputs, self.buttons | 
					
						
						|  |  | 
					
						
						|  | def make_header(self): | 
					
						
						|  | """ | 
					
						
						|  | This function defines the header of the app. It consists only of a png image | 
					
						
						|  | in which the title and an overview is given. | 
					
						
						|  | """ | 
					
						
						|  |  | 
					
						
						|  | header_container = st.container() | 
					
						
						|  | with header_container: | 
					
						
						|  | header = Image.open("./assets/header.png") | 
					
						
						|  | st.image(header) | 
					
						
						|  |  | 
					
						
						|  | def make_main_content_area(self, | 
					
						
						|  | predictor, | 
					
						
						|  | inputs, | 
					
						
						|  | buttons, | 
					
						
						|  | create_prediction_df: callable, | 
					
						
						|  | create_molecule_grid_plot: callable): | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | tab1, tab2, tab3, tab4 = st.tabs(["Predictions", | 
					
						
						|  | "Paper / Cite", | 
					
						
						|  | "Additional Information", | 
					
						
						|  | "Examples"]) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | with tab1: | 
					
						
						|  | self._fill_tab_with_results_content(predictor, | 
					
						
						|  | inputs, | 
					
						
						|  | buttons, | 
					
						
						|  | create_prediction_df, | 
					
						
						|  | create_molecule_grid_plot) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | with tab2: | 
					
						
						|  | self._fill_paper_and_citation_tab() | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | with tab3: | 
					
						
						|  | self._fill_more_explanations_tab() | 
					
						
						|  |  | 
					
						
						|  | with tab4: | 
					
						
						|  | self._fill_examples_tab() | 
					
						
						|  |  | 
					
						
						|  | def _make_query_box(self): | 
					
						
						|  | """ | 
					
						
						|  | This function | 
					
						
						|  | a) defines the query box and | 
					
						
						|  | b) stores the query input in the inputs dictionary | 
					
						
						|  | """ | 
					
						
						|  |  | 
					
						
						|  | st.info(":blue[Molecules to predict:]", icon="β") | 
					
						
						|  |  | 
					
						
						|  | query_container = st.container() | 
					
						
						|  | with query_container: | 
					
						
						|  | input_choice = st.radio( | 
					
						
						|  | "Input your data in SMILES notation via:", ["Text box", "CSV upload"] | 
					
						
						|  | ) | 
					
						
						|  | if input_choice == "Text box": | 
					
						
						|  | query_input = st.text_area( | 
					
						
						|  | label="SMILES input for query molecules", | 
					
						
						|  | label_visibility="hidden", | 
					
						
						|  | key="query_textbox", | 
					
						
						|  | value= "Cc1nc(N2CCN(Cc3ccccc3)CC2)c(C#N)c(=O)n1CC(=O)O, " | 
					
						
						|  | "N#Cc1c(-c2ccccc2)nc(-c2cccc3c(Br)cccc23)n(CC(=O)O)c1=O, " | 
					
						
						|  | "Cc1nc(N2CCC(Cc3ccccc3)CC2)c(C#N)c(=O)n1CC(=O)O, " | 
					
						
						|  | "CC(C)Sc1nc(C(C)(C)C)nc(OCC(=O)O)c1C#N, " | 
					
						
						|  | "Cc1nc(NCc2cccnc2)cc(=O)n1CC(=O)O, " | 
					
						
						|  | "COC(=O)c1c(SC)nc(C2CCCCC2)n(CC(=O)O)c1=O, " | 
					
						
						|  | "Cc1nc(NCc2cccnc2)c(C#N)c(=O)n1CC(=O)O, " | 
					
						
						|  | "CC(C)c1nc(SCc2ccccc2)c(C#N)c(=O)n1CC(=O)O, " | 
					
						
						|  | "N#Cc1c(OCC(=O)O)nc(-c2cccc3ccccc23)nc1-c1ccccc1, " | 
					
						
						|  | "COc1ccc2c(C(=S)N(C)CC(=O)O)cccc2c1C(F)(F)F" | 
					
						
						|  | ) | 
					
						
						|  | elif input_choice == "CSV upload": | 
					
						
						|  | query_file = st.file_uploader(key="query_csv", | 
					
						
						|  | label = "CSV upload for query mols", | 
					
						
						|  | label_visibility="hidden") | 
					
						
						|  | if query_file is not None: | 
					
						
						|  | query_input = pd.read_csv(query_file) | 
					
						
						|  | else: query_input = None | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | self.inputs["query"] = query_input | 
					
						
						|  |  | 
					
						
						|  | def _make_active_support_set_box(self): | 
					
						
						|  | """ | 
					
						
						|  | This function | 
					
						
						|  | a) defines the active support set box and | 
					
						
						|  | b) stores the active support set input in the inputs dictionary | 
					
						
						|  | """ | 
					
						
						|  |  | 
					
						
						|  | st.info(":blue[Known active molecules:]", icon="β¨") | 
					
						
						|  | active_container = st.container() | 
					
						
						|  | with active_container: | 
					
						
						|  | active_input_choice = st.radio( | 
					
						
						|  | "Input your data in SMILES notation via:", | 
					
						
						|  | ["Text box", "CSV upload"], | 
					
						
						|  | key="active_input_choice", | 
					
						
						|  | ) | 
					
						
						|  |  | 
					
						
						|  | if active_input_choice == "Text box": | 
					
						
						|  | support_active_input = st.text_area( | 
					
						
						|  | label="SMILES input for active support set molecules", | 
					
						
						|  | label_visibility="hidden", | 
					
						
						|  | key="active_textbox", | 
					
						
						|  | value="CC(C)(C)c1nc(OCC(=O)O)c(C#N)c(SCC2CCCCC2)n1, " | 
					
						
						|  | "Cc1nc(NCC2CCCCC2)c(C#N)c(=O)n1CC(=O)O" | 
					
						
						|  | ) | 
					
						
						|  | elif active_input_choice == "CSV upload": | 
					
						
						|  | support_active_file = st.file_uploader( | 
					
						
						|  | key="support_active_csv", | 
					
						
						|  | label = "CSV upload for active support set molecules", | 
					
						
						|  | label_visibility="hidden" | 
					
						
						|  | ) | 
					
						
						|  | if support_active_file is not None: | 
					
						
						|  | support_active_input  = pd.read_csv(support_active_file) | 
					
						
						|  | else: support_active_input = None | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | self.inputs["support_active"] = support_active_input | 
					
						
						|  |  | 
					
						
						|  | def _make_inactive_support_set_box(self): | 
					
						
						|  | st.info(":blue[Known inactive molecules:]", icon="β¨") | 
					
						
						|  | inactive_container = st.container() | 
					
						
						|  | with inactive_container: | 
					
						
						|  | inactive_input_choice = st.radio( | 
					
						
						|  | "Input your data in SMILES notation via:", | 
					
						
						|  | ["Text box", "CSV upload"], | 
					
						
						|  | key="inactive_input_choice", | 
					
						
						|  | ) | 
					
						
						|  | if inactive_input_choice == "Text box": | 
					
						
						|  | support_inactive_input  = st.text_area( | 
					
						
						|  | label="SMILES input for inactive support set molecules", | 
					
						
						|  | label_visibility="hidden", | 
					
						
						|  | key="inactive_textbox", | 
					
						
						|  | value="CSc1nc(C2CCCCC2)n(CC(=O)O)c(=O)c1S(=O)(=O)c1ccccc1, " | 
					
						
						|  | "CSc1nc(C)nc(OCC(=O)O)c1C#N" | 
					
						
						|  | ) | 
					
						
						|  | elif inactive_input_choice == "CSV upload": | 
					
						
						|  | support_inactive_file  = st.file_uploader( | 
					
						
						|  | key="support_inactive_csv", | 
					
						
						|  | label = "CSV upload for inactive support set molecules", | 
					
						
						|  | label_visibility="hidden" | 
					
						
						|  | ) | 
					
						
						|  | if support_inactive_file is not None: | 
					
						
						|  | support_inactive_input  = pd.read_csv( | 
					
						
						|  | support_inactive_file | 
					
						
						|  | ) | 
					
						
						|  | else: support_inactive_input = None | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | self.inputs["support_inactive"] = support_inactive_input | 
					
						
						|  |  | 
					
						
						|  | def _fill_tab_with_results_content(self, predictor, inputs, buttons, | 
					
						
						|  | create_prediction_df, create_molecule_grid_plot): | 
					
						
						|  | tab_container = st.container() | 
					
						
						|  | with tab_container: | 
					
						
						|  |  | 
					
						
						|  | st.info(":blue[Summary:]", icon="π") | 
					
						
						|  | st.markdown(self.summary_text) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | st.info(":blue[Results:]",icon="π¨βπ»") | 
					
						
						|  |  | 
					
						
						|  | if buttons['predict']: | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | if (inputs['query'] is None or | 
					
						
						|  | inputs['support_active'] is None or | 
					
						
						|  | inputs['support_inactive'] is None): | 
					
						
						|  | st.error("You didn't provide all necessary inputs.\n\n" | 
					
						
						|  | "Please provide all three necessary inputs via the " | 
					
						
						|  | "sidebar and hit the predict button again.") | 
					
						
						|  | else: | 
					
						
						|  |  | 
					
						
						|  | max_input_length = 0 | 
					
						
						|  | for key, input in inputs.items(): | 
					
						
						|  | input_list = handle_inputs(input) | 
					
						
						|  | self.inputs_lists[key] = input_list | 
					
						
						|  | max_input_length = max(max_input_length, len(input_list)) | 
					
						
						|  |  | 
					
						
						|  | if max_input_length > self.max_input_length: | 
					
						
						|  | st.error("You provided too many molecules. The number of " | 
					
						
						|  | "molecules for each input is restricted to " | 
					
						
						|  | f"{self.max_input_length}.\n\n" | 
					
						
						|  | "For larger screenings, we suggest to clone the repo " | 
					
						
						|  | "and to run the model locally.") | 
					
						
						|  | else: | 
					
						
						|  |  | 
					
						
						|  | progress_bar_text = ("I'm predicting activities. This might " | 
					
						
						|  | "need some minutes. Please wait...") | 
					
						
						|  | progress_bar = st.progress(50, text=progress_bar_text) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | df = self._predict_and_create_results_table(predictor, | 
					
						
						|  | inputs, | 
					
						
						|  | create_prediction_df) | 
					
						
						|  |  | 
					
						
						|  | progress_bar_text = ("Done. Here are the results:") | 
					
						
						|  | progress_bar = progress_bar.progress(100, text=progress_bar_text) | 
					
						
						|  | st.dataframe(df, use_container_width=True) | 
					
						
						|  |  | 
					
						
						|  | col1, col2, col3, col4 = st.columns([1,1,1,1]) | 
					
						
						|  |  | 
					
						
						|  | with col2: | 
					
						
						|  | self.buttons["download_results"] = st.download_button( | 
					
						
						|  | "Download predictions as CSV", | 
					
						
						|  | self._convert_df_to_binary(df), | 
					
						
						|  | file_name="predictions.csv", | 
					
						
						|  | ) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | with col3: | 
					
						
						|  | with open("inputs.yml", 'w') as fl: | 
					
						
						|  | self.buttons["download_inputs"] = st.download_button( | 
					
						
						|  | "Download inputs as YML", | 
					
						
						|  | self._convert_to_yml(self.inputs_lists), | 
					
						
						|  | file_name="inputs.yml", | 
					
						
						|  | ) | 
					
						
						|  | st.divider() | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | st.info(":blue[Grid plot of the predicted molecules:]", | 
					
						
						|  | icon="π") | 
					
						
						|  | mol_html_grid = create_molecule_grid_plot(df) | 
					
						
						|  | components.html(mol_html_grid, height=1000, scrolling=True) | 
					
						
						|  |  | 
					
						
						|  | elif buttons['reset']: | 
					
						
						|  | self._reset() | 
					
						
						|  |  | 
					
						
						|  | def _fill_paper_and_citation_tab(self): | 
					
						
						|  | st.info(":blue[**Paper: Context-enriched molecule representations improve " | 
					
						
						|  | "few-shot drug discovery**]", icon="π") | 
					
						
						|  | st.markdown(self.mhnfs_text, unsafe_allow_html=True) | 
					
						
						|  | st.image("./assets/mhnfs_overview.png") | 
					
						
						|  | st.write("") | 
					
						
						|  | st.write("") | 
					
						
						|  | st.write("") | 
					
						
						|  | st.info(":blue[**Cite us / BibTex**]", icon="π") | 
					
						
						|  | st.markdown(self.citation_text) | 
					
						
						|  |  | 
					
						
						|  | def _fill_more_explanations_tab(self): | 
					
						
						|  | st.info(":blue[**Under the hood**]", icon="βοΈ") | 
					
						
						|  | st.markdown(self.under_the_hood_text, unsafe_allow_html=True) | 
					
						
						|  | st.write("") | 
					
						
						|  | st.write("") | 
					
						
						|  |  | 
					
						
						|  | st.info(":blue[**About few-shot learning and the model MHNfs**]", icon="π―") | 
					
						
						|  | st.markdown(self.few_shot_learning_text, unsafe_allow_html=True) | 
					
						
						|  | st.write("") | 
					
						
						|  | st.write("") | 
					
						
						|  |  | 
					
						
						|  | st.info(":blue[**Usage**]", icon="ποΈ") | 
					
						
						|  | st.markdown(self.usage_text, unsafe_allow_html=True) | 
					
						
						|  | st.write("") | 
					
						
						|  | st.write("") | 
					
						
						|  |  | 
					
						
						|  | st.info(":blue[**How to provide the data**]", icon="π") | 
					
						
						|  | st.markdown(self.data_text, unsafe_allow_html=True) | 
					
						
						|  | st.write("") | 
					
						
						|  | st.write("") | 
					
						
						|  |  | 
					
						
						|  | st.info(":blue[**When to trust the predictions**]", icon="π") | 
					
						
						|  | st.markdown(self.trust_text, unsafe_allow_html=True) | 
					
						
						|  |  | 
					
						
						|  | def _fill_examples_tab(self): | 
					
						
						|  | st.info(":blue[**Example for trustworthy predictions**]", icon="β
") | 
					
						
						|  | st.markdown(self.example_trustworthy_text, unsafe_allow_html=True) | 
					
						
						|  | st.dataframe(self.df_trustworthy, use_container_width=True) | 
					
						
						|  | st.markdown("**Plot: Predictions for active and inactive molecules (model AUC=" | 
					
						
						|  | "0.96**)") | 
					
						
						|  | prediction_plot_tw = Image.open("./assets/example_csv/predictions/" | 
					
						
						|  | "trustworthy_example.png") | 
					
						
						|  | st.image(prediction_plot_tw) | 
					
						
						|  | st.write("") | 
					
						
						|  | st.write("") | 
					
						
						|  |  | 
					
						
						|  | st.info(":blue[**Example for not trustworthy predictions**]", icon="βοΈ") | 
					
						
						|  | st.markdown(self.example_nottrustworthy_text, unsafe_allow_html=True) | 
					
						
						|  | st.dataframe(self.df_nottrustworthy, use_container_width=True) | 
					
						
						|  | st.markdown("**Plot: Predictions for active and inactive molecules (model AUC=" | 
					
						
						|  | "0.42**)") | 
					
						
						|  | prediction_plot_ntw = Image.open("./assets/example_csv/predictions/" | 
					
						
						|  | "nottrustworthy_example.png") | 
					
						
						|  | st.image(prediction_plot_ntw) | 
					
						
						|  |  | 
					
						
						|  | def _predict_and_create_results_table(self, | 
					
						
						|  | predictor, | 
					
						
						|  | inputs, | 
					
						
						|  | create_prediction_df: callable): | 
					
						
						|  |  | 
					
						
						|  | df = create_prediction_df(predictor, | 
					
						
						|  | inputs['query'], | 
					
						
						|  | inputs['support_active'], | 
					
						
						|  | inputs['support_inactive']) | 
					
						
						|  | return df | 
					
						
						|  |  | 
					
						
						|  | def _reset(self): | 
					
						
						|  | keys = list(st.session_state.keys()) | 
					
						
						|  | for key in keys: | 
					
						
						|  | st.session_state.pop(key) | 
					
						
						|  |  | 
					
						
						|  | def _convert_df_to_binary(_self, df): | 
					
						
						|  | return df.to_csv(index=False).encode('utf-8') | 
					
						
						|  |  | 
					
						
						|  | def _convert_to_yml(_self, inputs): | 
					
						
						|  | return yaml.dump(inputs) | 
					
						
						|  | content = """ | 
					
						
						|  | # Usage | 
					
						
						|  | As soon as you have a few active and inactive molecules for your task, you can | 
					
						
						|  | provide them here and make predictions for new molecules. | 
					
						
						|  |  | 
					
						
						|  | ## About few-shot learning and the model MHNfs | 
					
						
						|  | **Few-shot learning** is a machine learning sub-field which aims to provide | 
					
						
						|  | predictive models for scenarios in which only little data is known/available. | 
					
						
						|  |  | 
					
						
						|  | **MHNfs** is a few-shot learning model which is specifically designed for drug | 
					
						
						|  | discovery applications. It is built to use the input prompts in a way such that | 
					
						
						|  | the provided available knowledge - i.e. the known active and inactive molecules - | 
					
						
						|  | functions as context to predict the activity of the new requested molecules. | 
					
						
						|  | Precisely, the provided active and inactive molecules are associated with a | 
					
						
						|  | large set of general molecules - called context molecules - to enrich the | 
					
						
						|  | provided information and to remove spurious correlations arising from the | 
					
						
						|  | decoration of molecules. This is analogous to a Large Language Model which would | 
					
						
						|  | not only use the provided information in the current prompt as context but would | 
					
						
						|  | also have access to way more information, e.g. a prompting history. | 
					
						
						|  |  | 
					
						
						|  | ## How to provide the data | 
					
						
						|  | * Molecules have to be provided in SMILES format. | 
					
						
						|  | * You can provide the molecules via the text boxes or via CSV upload. | 
					
						
						|  | - Text box: Replace the pseudo input by directly typing your molecules into | 
					
						
						|  | the text box. Please separate the molecules by comma. | 
					
						
						|  | - CSV upload: Upload a CSV file with the molecules. | 
					
						
						|  | * The CSV file should include a smiles column (both upper and lower | 
					
						
						|  | case "SMILES" are accepted). | 
					
						
						|  | * All other columns will be ignored. | 
					
						
						|  |  | 
					
						
						|  | ## When to trust the predictions | 
					
						
						|  | Just like all other machine learning models, the performance of MHNfs varies | 
					
						
						|  | and, generally, the model works well if the task is somehow close to tasks which | 
					
						
						|  | were used to train the model. The model performance for very different tasks is | 
					
						
						|  | unclear and might be poor. | 
					
						
						|  |  | 
					
						
						|  | MHNfs was trained on a the FS-Mol dataset which includes 5120 tasks (Roughly | 
					
						
						|  | 5000 tasks were used for training, rest for evaluation). The training tasks are | 
					
						
						|  | listed here: https://github.com/microsoft/FS-Mol/tree/main/datasets/targets. | 
					
						
						|  | """ | 
					
						
						|  | return content |