import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import shap
import yaml

from inference_polymers_gnn import predict


def create_shap_plot(shap_values, df, num_targets=1):
    """
    Draws two side-by-side SHAP summary plots and returns the resulting matplotlib figure.

    Args:
        shap_values: indexable collection of SHAP value arrays. Entries 0 and 1 are plotted,
            so it must contain at least two entries.
        df: DataFrame of the explained samples; its columns are used as feature names.
        num_targets: currently unused; kept for interface compatibility.

    Returns:
        The current matplotlib figure holding the two summary plots.
    """
    # TODO improve shap interpreter: the layout is hard-coded to two panels regardless of num_targets.
    # pyplot keeps every figure alive until it is explicitly closed; the previous clf() + figure()
    # pair therefore left one extra open figure behind on every call.
    plt.close("all")
    plt.figure(figsize=(15, 15))
    for panel_index in range(2):
        plt.subplot(1, 2, panel_index + 1)
        shap.summary_plot(
            shap_values[panel_index],
            df,
            show=False,
            feature_names=df.columns,
            plot_size=(15, 15),
        )
    plt.tight_layout()
    plt.subplots_adjust(wspace=2.0)
    return plt.gcf()


def call_predict(inference_dict, cols_order, numerical_columns, target_columns):
    """
    Encapsulates the predict function to pass the config, and to put the data in the right format.

    Args:
        inference_dict: inference configuration; only its "model_path" entry is read here.
        cols_order: column names, in the same order as the values passed by the gradio inputs.
        numerical_columns: currently unused; kept for interface compatibility.
        target_columns: names of the predicted targets.

    Returns:
        A callable that takes the gradio input values positionally and returns the flat list
        [prediction_0, uncertainty_0, prediction_1, uncertainty_1, ...].
    """

    def predict_from_list(x_list):
        # A single-row frame whose columns follow the gradio input order.
        df = pd.DataFrame([x_list], columns=cols_order)
        y_pred = predict(df, model_path=inference_dict["model_path"])

        outputs = []
        for target_index in range(len(target_columns)):
            outputs.append(y_pred[target_index][0])
            # NOTE(review): placeholder — this "uncertainty" is random noise drawn from [2, 6),
            # not an estimate produced by the model.
            outputs.append(np.round(np.random.uniform(2, 6), 1))
        # NOTE(review): no interpretation figure is returned. create_gradio_interface appends an
        # extra Plot output when inverse_design=False, which would leave gradio one value short —
        # confirm the caller builds the interface with inverse_design=True.
        return outputs

    return lambda *x: predict_from_list(x)


def initialize_config(config_name):
    """
    Loads the configuration and defines the color theme.

    Args:
        config_name: path to the YAML configuration file.

    Returns:
        Tuple (config, input_cols_order, target_columns, numerical_columns,
        osium_theme, css_styling, example_inputs).
    """
    osium_theme_colors = gr.themes.Color(
        c50="#e4f3fa",  # Dataframe background cell content - light mode only
        c100="#e4f3fa",  # Top corner of clear button in light mode + markdown text in dark mode
        c200="#a1c6db",  # Component borders
        c300="#FFFFFF",  #
        c400="#e4f3fa",  # Footer text
        c500="#0c1538",  # Text of component headers in light mode only
        c600="#a1c6db",  # Top corner of button in dark mode
        c700="#475383",  # Button text in light mode + component borders in dark mode
        c800="#0c1538",  # Markdown text in light mode
        c900="#a1c6db",  # Background of dataframe - dark mode
        c950="#0c1538",  # Background in dark mode only
    )
    # secondary color used for highlight box content when typing in light mode, and download option in dark mode
    # primary color used for login button in dark mode
    osium_theme = gr.themes.Default(primary_hue="cyan", secondary_hue="cyan", neutral_hue=osium_theme_colors)

    css_styling = """#submit {background: #1eccd8}
#submit:hover {background: #a2f1f6}
.output-image, .input-image, .image-preview {height: 250px !important}
.output-plot {height: 250px !important}
#interpretation {height: 250px !important}"""

    with open(config_name, "r", encoding="utf-8") as file:
        config = yaml.safe_load(file)

    # Flatten the sectioned input layout into a single ordered list of column names.
    input_cols_order = [col_name for section_dict in config["input_order"] for col_name in section_dict["keys"]]
    numerical_columns = [
        col_name
        for col_name, col_config in config["input_mapping"].items()
        if col_config["comp_type"] == "Number"
    ]
    example_inputs = [config["input_mapping"][col_name]["example"] for col_name in input_cols_order]
    # Uncertainty outputs are displayed but are not model targets.
    target_columns = [
        col_name
        for section_dict in config["output_order"]
        for col_name in section_dict["keys"]
        if not col_name.endswith("_uncertainty")
    ]

    return config, input_cols_order, target_columns, numerical_columns, osium_theme, css_styling, example_inputs


def add_gradio_component(config_dict, component_key):
    """
    Creates a gradio component for component_key, based on the config_dict dictionary of parameters.

    Args:
        config_dict: mapping from component key to its parameters; each entry has a "comp_type"
            plus the type-specific keys read below ("label", "cat_values", "minimum", ...).
        component_key: key of the component to build.

    Returns:
        The created gradio component, or None (after printing a message) for an unsupported type.
    """
    params = config_dict[component_key]
    comp_type = params["comp_type"]

    if comp_type == "Text":
        return gr.Text(label=params["label"], placeholder=params["label"])
    if comp_type == "Number":
        return gr.Number(label=params["label"], precision=3)
    if comp_type == "Dropdown":
        return gr.Dropdown(label=params["label"], choices=params["cat_values"])
    if comp_type == "Image":
        return gr.Image(elem_classes="image-preview")
    if comp_type == "CheckboxGroup":
        return gr.CheckboxGroup(label=params["label"], choices=params["cat_values"])
    if comp_type == "Plot":
        return gr.Plot(label=params["label"], type="matplotlib")
    if comp_type == "Dataframe":
        return gr.Dataframe(wrap=True, type="pandas")
    if comp_type == "Slider":
        return gr.Slider(
            label=params["label"],
            minimum=params["minimum"],
            maximum=params["maximum"],
            step=params["step"],
        )

    print(f"Found component type {comp_type} for {component_key}, which is not supported")
    return None


def create_gradio_interface(
    input_order,
    input_mapping,
    output_order,
    output_mapping,
    example_inputs,
    additional_markdown,
    size,
    osium_theme,
    css_styling,
    predict_fn,
    inverse_design=False,
):
    """
    Creates the gradio visual interface from the configuration file.

    Args:
        input_order / output_order: lists of sections, each a dict with "markdown" (section title)
            and "keys" (component keys, in display order).
        input_mapping / output_mapping: per-key component parameters (see add_gradio_component).
        example_inputs: one example value per input component, in input order.
        additional_markdown: texts read by key ("page_title", "main_title", "details",
            and "interpretation" when inverse_design is False).
        size: layout options read by key ("input_column_scale", "input_column_min_width").
        osium_theme, css_styling: theme and CSS as returned by initialize_config.
        predict_fn: callback wired to the Predict button; receives one value per input component
            and must return one value per output component.
        inverse_design: when False, an extra "Interpretation" Plot output is appended.

    Returns:
        The gradio Blocks application.
    """
    with gr.Blocks(css=css_styling, title=additional_markdown["page_title"], theme=osium_theme) as demo:
        gr.Markdown(f"# <p style='text-align: center;'>{additional_markdown['main_title']}</p>")
        gr.Markdown(additional_markdown["details"])

        with gr.Row():
            clear_button = gr.Button("Clear")
            prediction_button = gr.Button("Predict", elem_id="submit")

        input_list = []
        output_list = []

        with gr.Row():
            # Input component section
            with gr.Column(scale=size["input_column_scale"], min_width=size["input_column_min_width"]):
                for section_dict in input_order:
                    gr.Markdown(f"### {section_dict['markdown']}")
                    for col_name in section_dict["keys"]:
                        input_list.append(add_gradio_component(input_mapping, col_name))

            # Output component section
            with gr.Column():
                with gr.Row():
                    for section_dict in output_order:
                        with gr.Column():
                            gr.Markdown(f"### {section_dict['markdown']}")
                            for col_name in section_dict["keys"]:
                                output_list.append(add_gradio_component(output_mapping, col_name))

                if not inverse_design:
                    # Currently one plot contains all the interpretation figures
                    with gr.Row():
                        with gr.Column():
                            with gr.Row():
                                gr.Markdown(f"### {additional_markdown['interpretation']}")
                            with gr.Row():
                                output_interpretation = gr.Plot(label="Interpretation", type="matplotlib")
                                output_list.append(output_interpretation)

        with gr.Row():
            gr.Examples([example_inputs], input_list)

        prediction_button.click(
            fn=predict_fn,
            inputs=input_list,
            outputs=output_list,
            show_progress=True,
        )

        components_to_clear = input_list + output_list
        # With inputs=[] gradio calls the handler with no arguments, so it must take none
        # (the previous `lambda x: ...` raised TypeError on every click).
        clear_button.click(
            lambda: [gr.update(value=None)] * len(components_to_clear),
            [],
            components_to_clear,
        )

    return demo