File size: 8,995 Bytes
4f5540c
 
 
 
 
 
 
 
 
 
 
5a68ba1
4f5540c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e49d9e7
 
 
4f5540c
 
 
 
 
 
 
 
 
 
2e494c5
566ba9a
4f5540c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2e494c5
 
4f5540c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
# import os
# import csv
import gradio as gr

# import tensorflow as tf
# import numpy as np
import pandas as pd
import yaml
import shap
import matplotlib.pyplot as plt
from inference_polymers_gnn import predict
import numpy as np

# from datetime import datetime
# import utils
# from huggingface_hub import Repository
# import itertools
# import time
# import cv2
# from prediction_coatings import predict
# from utils import predict, unpickle_file, scale_numerical, encode_categorical


def create_shap_plot(shap_values, df, num_targets=1):
    """
    Draw SHAP summary plots for the first two targets side by side and return the figure.

    Args:
        shap_values: indexable per-target SHAP values; entries 0 and 1 are plotted.
        df: feature data passed to ``shap.summary_plot``; its columns are used as
            feature names.
        num_targets: currently unused - the layout is hard-coded to two panels.

    Returns:
        The current matplotlib figure holding both summary plots.
    """
    # TODO improve shap interpreter
    # NOTE(review): num_targets is ignored and shap_values[1] is always read, so a
    # single-target input would likely fail here - confirm the intended layout.
    plt.clf()
    plt.figure(figsize=(15, 15))
    feature_names = df.columns
    for panel_index in (0, 1):
        # Select the left/right panel of a 1x2 grid before shap draws into it.
        plt.subplot(1, 2, panel_index + 1)
        shap.summary_plot(
            shap_values[panel_index],
            df,
            show=False,
            feature_names=feature_names,
            plot_size=(15, 15),
        )
    plt.tight_layout()
    # Wide horizontal gap so the two panels' feature labels do not overlap.
    plt.subplots_adjust(wspace=2.0)
    return plt.gcf()


def call_predict(inference_dict, cols_order, numerical_columns, target_columns):
    """
    Build the Gradio prediction callback around ``predict`` from ``inference_polymers_gnn``.

    The returned callable accepts the input component values positionally (in
    ``cols_order`` order), packs them into a one-row DataFrame and runs ``predict``
    with ``model_path=inference_dict["model_path"]``.

    Args:
        inference_dict: dict holding at least the ``"model_path"`` key.
        cols_order: column names, in the same order as the Gradio input components.
        numerical_columns: currently unused; kept for interface compatibility.
        target_columns: names of the predicted targets; one
            (prediction, uncertainty) pair is emitted per entry.

    Returns:
        A callable ``f(*inputs) -> list`` returning
        ``[pred_0, unc_0, pred_1, unc_1, ...]``.
    """
    import logging

    logger = logging.getLogger(__name__)

    def predict_from_list(x_list):
        df = pd.DataFrame([x_list], columns=cols_order)
        # Debug-level logging instead of unconditional prints, so each request no
        # longer dumps the full input row and raw predictions to stdout.
        logger.debug("Prediction input with shape %s:\n%s", df.shape, df)

        y_pred = predict(df, model_path=inference_dict["model_path"])
        logger.debug("Raw predictions (%d entries): %s", len(y_pred), y_pred)

        # TODO: the interpretation figure (create_shap_plot) is not produced yet.
        # NOTE(review): when the interface adds an "Interpretation" plot output
        # (inverse_design=False), this list has no entry for it - confirm wiring.
        outputs = []
        for target_index, _target_name in enumerate(target_columns):
            # Assumes y_pred[i] is indexable with the single-row prediction at [0].
            outputs.append(y_pred[target_index][0])
            # NOTE(review): this "uncertainty" is a random placeholder drawn from
            # U(2, 6) and rounded to 1 decimal, not a model estimate.
            outputs.append(np.round(np.random.uniform(2, 6), 1))
        return outputs

    return lambda *x: predict_from_list(x)


def initialize_config(config_name):
    """
    Load the YAML configuration and define the colour theme and CSS of the app.

    Args:
        config_name: path to the YAML configuration file. It must provide
            ``input_order`` and ``output_order`` (lists of sections, each with a
            ``"keys"`` list) and ``input_mapping`` (per-column dicts with at least
            ``"comp_type"`` and ``"example"``).

    Returns:
        Tuple ``(config, input_cols_order, target_columns, numerical_columns,
        osium_theme, css_styling, example_inputs)``:
        - ``input_cols_order``: input column names in display order.
        - ``target_columns``: output keys not ending in ``"_uncertainty"``.
        - ``numerical_columns``: input keys whose ``comp_type`` is ``"Number"``.
        - ``example_inputs``: one ``"example"`` value per input column, in display order.
    """
    osium_theme_colors = gr.themes.Color(
        c50="#e4f3fa",  # Dataframe background cell content - light mode only
        c100="#e4f3fa",  # Top corner of clear button in light mode + markdown text in dark mode
        c200="#a1c6db",  # Component borders
        c300="#FFFFFF",  #
        c400="#e4f3fa",  # Footer text
        c500="#0c1538",  # Text of component headers in light mode only
        c600="#a1c6db",  # Top corner of button in dark mode
        c700="#475383",  # Button text in light mode + component borders in dark mode
        c800="#0c1538",  # Markdown text in light mode
        c900="#a1c6db",  # Background of dataframe - dark mode
        c950="#0c1538",  # Background in dark mode only
    )
    # secondary color used for highlight box content when typing in light mode, and download option in dark mode
    # primary color used for login button in dark mode
    osium_theme = gr.themes.Default(primary_hue="cyan", secondary_hue="cyan", neutral_hue=osium_theme_colors)

    css_styling = """#submit {background: #1eccd8} 
    #submit:hover {background: #a2f1f6} 
    .output-image, .input-image, .image-preview {height: 250px !important}
    .output-plot {height: 250px !important}
    #interpretation {height: 250px !important}"""

    # Explicit UTF-8: the default encoding is locale-dependent, which can break
    # non-ASCII labels/units in the config on non-UTF-8 systems.
    with open(config_name, "r", encoding="utf-8") as file:
        config = yaml.safe_load(file)

    input_cols_order = [col_name for section_dict in config["input_order"] for col_name in section_dict["keys"]]
    numerical_columns = [
        col_name for col_name, col_spec in config["input_mapping"].items() if col_spec["comp_type"] == "Number"
    ]
    example_inputs = [config["input_mapping"][col_name]["example"] for col_name in input_cols_order]

    # Uncertainty outputs are paired with their target and are not targets themselves.
    target_columns = [
        col_name
        for section_dict in config["output_order"]
        for col_name in section_dict["keys"]
        if not col_name.endswith("_uncertainty")
    ]
    return config, input_cols_order, target_columns, numerical_columns, osium_theme, css_styling, example_inputs


def add_gradio_component(config_dict, component_key):
    """
    Build the Gradio component described by ``config_dict[component_key]``.

    The entry's ``"comp_type"`` selects the widget class. The remaining keys
    (``"label"``, ``"cat_values"``, ``"minimum"``, ``"maximum"``, ``"step"``) are
    read only by the component types that use them.

    Returns:
        The new component, or ``None`` (after printing a message) when
        ``comp_type`` is not one of the supported values.
    """
    spec = config_dict[component_key]
    kind = spec["comp_type"]

    if kind == "Text":
        return gr.Text(label=spec["label"], placeholder=spec["label"])
    if kind == "Number":
        return gr.Number(label=spec["label"], precision=3)
    if kind == "Dropdown":
        return gr.Dropdown(label=spec["label"], choices=spec["cat_values"])
    if kind == "Image":
        return gr.Image(elem_classes="image-preview")
    if kind == "CheckboxGroup":
        return gr.CheckboxGroup(label=spec["label"], choices=spec["cat_values"])
    if kind == "Plot":
        return gr.Plot(label=spec["label"], type="matplotlib")
    if kind == "Dataframe":
        return gr.Dataframe(wrap=True, type="pandas")
    if kind == "Slider":
        return gr.Slider(
            label=spec["label"],
            minimum=spec["minimum"],
            maximum=spec["maximum"],
            step=spec["step"],
        )

    # Unsupported types are reported on stdout and yield None instead of raising.
    print(f"Found component type {kind} for {component_key}, which is not supported")
    return None


def create_gradio_interface(
    input_order,
    input_mapping,
    output_order,
    output_mapping,
    example_inputs,
    additional_markdown,
    size,
    osium_theme,
    css_styling,
    predict_fn,
    inverse_design=False,
):
    """
    Create the Gradio Blocks UI described by the configuration.

    Args:
        input_order: list of sections, each with a ``"markdown"`` heading and the
            ``"keys"`` of the input components shown under it.
        input_mapping: per-key component specs consumed by ``add_gradio_component``.
        output_order: same structure as ``input_order``, for the output components.
        output_mapping: per-key component specs for the outputs.
        example_inputs: one example value per input component, in display order.
        additional_markdown: texts for ``"page_title"``, ``"main_title"``,
            ``"details"`` and, unless ``inverse_design`` is set, ``"interpretation"``.
        size: layout values ``"input_column_scale"`` and ``"input_column_min_width"``.
        osium_theme: Gradio theme applied to the page.
        css_styling: custom CSS string applied to the page.
        predict_fn: callback wired to the "Predict" button; it receives the input
            values positionally and must return one value per output component.
        inverse_design: when False, an extra "Interpretation" plot is appended to
            the outputs.

    Returns:
        The configured ``gr.Blocks`` instance (not launched).
    """
    with gr.Blocks(css=css_styling, title=additional_markdown["page_title"], theme=osium_theme) as demo:
        gr.Markdown(f"# <p style='text-align: center;'>{additional_markdown['main_title']}</p>")
        gr.Markdown(additional_markdown["details"])

        with gr.Row():
            clear_button = gr.Button("Clear")
            prediction_button = gr.Button("Predict", elem_id="submit")

        input_list = []
        output_list = []
        with gr.Row():
            # Input component section
            with gr.Column(scale=size["input_column_scale"], min_width=size["input_column_min_width"]):
                for section_dict in input_order:
                    gr.Markdown(f"### {section_dict['markdown']}")
                    for col_name in section_dict["keys"]:
                        input_list.append(add_gradio_component(input_mapping, col_name))
            # Output component section
            with gr.Column():
                with gr.Row():
                    for section_dict in output_order:
                        with gr.Column():
                            gr.Markdown(f"### {section_dict['markdown']}")
                            for col_name in section_dict["keys"]:
                                output_list.append(add_gradio_component(output_mapping, col_name))
                if not inverse_design:
                    # Currently one plot contains all the interpretation figures,
                    # so predict_fn must return one extra value for it.
                    with gr.Row():
                        with gr.Column():
                            with gr.Row():
                                gr.Markdown(f"### {additional_markdown['interpretation']}")
                            with gr.Row():
                                output_interpretation = gr.Plot(label="Interpretation", type="matplotlib")
                                output_list.append(output_interpretation)

        with gr.Row():
            gr.Examples([example_inputs], input_list)

        prediction_button.click(
            fn=predict_fn,
            inputs=input_list,
            outputs=output_list,
            show_progress=True,
        )
        # The clear handler is registered with no inputs, so Gradio calls it with
        # zero arguments; the previous one-argument lambda raised a TypeError.
        clear_button.click(
            lambda: [gr.update(value=None)] * (len(input_list) + len(output_list)),
            [],
            input_list + output_list,
        )
    return demo