File size: 9,387 Bytes
8da763b
 
 
 
 
 
 
 
47724d8
6e343b0
c14b7bc
8da763b
 
 
 
 
 
 
 
628fe32
8da763b
 
 
 
 
 
847f816
4f20fa3
 
 
3fb7dbc
04a60b4
 
3d2e726
 
04a60b4
30bb0f9
0e5ffa3
5b53bce
 
30bb0f9
 
 
847f816
8da763b
 
 
 
cec0cf0
4a2ab4d
847f816
30bb0f9
847f816
 
47724d8
 
6d548db
 
 
e0aa30b
6d548db
e0aa30b
b6995e3
 
6d548db
 
47724d8
31f2e06
 
325df2e
3d2e726
 
 
2d8f692
8da763b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
# import os
# import csv
import gradio as gr

# import tensorflow as tf
# import numpy as np
import pandas as pd
import yaml
import shap
import matplotlib.pyplot as plt
import numpy as np

# from datetime import datetime
# import utils
# from huggingface_hub import Repository
# import itertools
# import time
# import cv2
# from prediction_coatings import predict
from utils import predict, unpickle_file, scale_numerical, encode_categorical


def call_predict(inference_dict, cols_order):
    """
    Build the prediction callback used by the gradio interface.

    Loads the pickled input scaler, target scaler and categorical encoder named in
    ``inference_dict``, preprocesses ``dataset_preprocessed.csv`` (targets dropped)
    with them once, and returns a function that takes one positional value per
    column of ``cols_order`` and produces the values for the output components.

    Parameters
    ----------
    inference_dict : dict
        Must contain ``scaler_inputs_path``, ``scaler_targets_path``,
        ``encoder_path`` and ``model_path``.
    cols_order : list of str
        Input column names, in the order the gradio input components are passed.

    Returns
    -------
    callable
        ``f(*values)`` returning a 5-tuple: first predicted target, a placeholder
        uncertainty, second predicted target, a placeholder uncertainty, and the
        SHAP summary figure.
    """
    scaler_inputs = unpickle_file(inference_dict["scaler_inputs_path"])
    scaler_targets = unpickle_file(inference_dict["scaler_targets_path"])
    encoder = unpickle_file(inference_dict["encoder_path"])

    categorical_columns = ["infill_pattern", "material"]
    target_columns = ["roughness", "tension_strength", "elongation"]
    numerical_columns = [c for c in cols_order if c not in categorical_columns]

    # Preprocess the training inputs once, with the same (already fitted) scaler and
    # encoder as the user inputs; the result is passed to ``predict`` as ``df_train``
    # on every call (presumably the SHAP background set — confirm in utils.predict).
    df_train = pd.read_csv("dataset_preprocessed.csv", sep=";")
    df_train = df_train.drop(columns=target_columns)
    df_train = scale_numerical(df_train, numerical_columns, scaler=scaler_inputs, fit=False)
    df_train = encode_categorical(df_train, categorical_columns, encoder=encoder, fit=False)

    def predict_from_list(x_list):
        """Run the model on a single row of raw inputs ordered like ``cols_order``."""
        df = pd.DataFrame([x_list], columns=cols_order)

        df_preprocessed = scale_numerical(df, numerical_columns, scaler=scaler_inputs, fit=False)
        df_preprocessed = encode_categorical(df_preprocessed, categorical_columns, encoder=encoder, fit=False)

        y_pred, _, shap_values = predict(inference_dict["model_path"], df_preprocessed, df_train=df_train)

        # Back to the targets' original units.
        y_pred_rescaled = scaler_targets.inverse_transform(y_pred)

        # Close figures left over from earlier calls. The former plt.clf() only cleared
        # the previous figure before plt.figure() opened a new one, so each prediction
        # leaked an open pyplot figure.
        # NOTE(review): assumes gradio has already rendered the previous figure;
        # revisit if events are allowed to run concurrently.
        plt.close("all")
        plt.figure(figsize=(15, 15))
        # One SHAP summary per plotted output (the third target is not plotted).
        plt.subplot(1, 2, 1)
        shap.summary_plot(shap_values[0], df_preprocessed, show=False, feature_names=df_preprocessed.columns, plot_size=(15, 15))
        plt.subplot(1, 2, 2)
        shap.summary_plot(shap_values[1], df_preprocessed, show=False, feature_names=df_preprocessed.columns, plot_size=(15, 15))
        plt.tight_layout()
        plt.subplots_adjust(wspace=2.0)
        fig = plt.gcf()

        # NOTE(review): the two "uncertainty" values are random placeholders drawn
        # from U(2, 5), not model estimates.
        return (
            np.round(y_pred_rescaled[0][0], 1),
            np.round(np.random.uniform(2, 5), 1),
            np.round(y_pred_rescaled[0][1], 1),
            np.round(np.random.uniform(2, 5), 1),
            fig,
        )

    # gradio passes one positional argument per input component.
    return lambda *x: predict_from_list(x)


def initialize_config(config_name):
    """
    Load the YAML configuration and define the color theme and CSS of the app.

    Parameters
    ----------
    config_name : str or path-like
        Path to the YAML configuration file. It must define ``input_order`` (a list
        of sections, each with a ``keys`` list) and ``input_mapping`` (a dict mapping
        each key to a spec holding an ``example`` value).

    Returns
    -------
    tuple
        ``(config, cols_order, osium_theme, css_styling, example_inputs)``:
        ``cols_order`` is the flattened list of section keys in order, and
        ``example_inputs`` holds the matching example values in the same order.
    """
    osium_theme_colors = gr.themes.Color(
        c50="#e4f3fa",  # Dataframe background cell content - light mode only
        c100="#e4f3fa",  # Top corner of clear button in light mode + markdown text in dark mode
        c200="#a1c6db",  # Component borders
        c300="#FFFFFF",  # (usage not documented)
        c400="#e4f3fa",  # Footer text
        c500="#0c1538",  # Text of component headers in light mode only
        c600="#a1c6db",  # Top corner of button in dark mode
        c700="#475383",  # Button text in light mode + component borders in dark mode
        c800="#0c1538",  # Markdown text in light mode
        c900="#a1c6db",  # Background of dataframe - dark mode
        c950="#0c1538",  # Background in dark mode only
    )
    # secondary color used for highlight box content when typing in light mode, and download option in dark mode
    # primary color used for login button in dark mode
    osium_theme = gr.themes.Default(primary_hue="cyan", secondary_hue="cyan", neutral_hue=osium_theme_colors)

    css_styling = """#submit {background: #1eccd8} 
    #submit:hover {background: #a2f1f6} 
    .output-image, .input-image, .image-preview {height: 250px !important}
    .output-plot {height: 250px !important}
    #interpretation {height: 250px !important}"""

    # Explicit UTF-8 so non-ASCII labels/units decode identically on every platform;
    # the default text encoding is locale-dependent (e.g. cp1252 on Windows).
    with open(config_name, "r", encoding="utf-8") as file:
        config = yaml.safe_load(file)

    # Column order seen by the model = section order, then key order inside each section.
    cols_order = [col_name for section_dict in config["input_order"] for col_name in section_dict["keys"]]

    example_inputs = [config["input_mapping"][col_name]["example"] for col_name in cols_order]

    return config, cols_order, osium_theme, css_styling, example_inputs


def add_gradio_component(config_dict, component_key):
    """
    Build the gradio component described by ``config_dict[component_key]``.

    The entry's ``comp_type`` selects the widget; ``label`` and ``cat_values`` are
    read only by the widget types that use them. An unsupported ``comp_type`` is
    reported on stdout and yields ``None``.
    """
    spec = config_dict[component_key]
    comp_type = spec["comp_type"]

    if comp_type == "Text":
        return gr.Text(label=spec["label"], placeholder=spec["label"])
    if comp_type == "Number":
        return gr.Number(label=spec["label"], precision=3)
    if comp_type == "Dropdown":
        return gr.Dropdown(label=spec["label"], choices=spec["cat_values"])
    if comp_type == "Image":
        return gr.Image(elem_classes="image-preview")
    if comp_type == "CheckboxGroup":
        return gr.CheckboxGroup(label=spec["label"], choices=spec["cat_values"])
    if comp_type == "Plot":
        return gr.Plot(label=spec["label"], type="matplotlib")
    if comp_type == "Dataframe":
        return gr.Dataframe(wrap=True, type="pandas")

    print(f"Found component type {comp_type} for {component_key}, which is not supported")
    return None


def create_gradio_interface(
    input_order,
    input_mapping,
    output_order,
    output_mapping,
    example_inputs,
    additional_markdown,
    size,
    osium_theme,
    css_styling,
    predict_fn,
    inverse_design=False,
):
    """
    Create the gradio visual interface from the configuration file.

    Parameters
    ----------
    input_order, output_order : list of dict
        Sections, each with a ``markdown`` heading and a list of ``keys``. The key
        order fixes the order of the components passed to / filled by ``predict_fn``.
    input_mapping, output_mapping : dict
        Per-key component specs, consumed by ``add_gradio_component``.
    example_inputs : list
        One example row, in input-component order, shown through ``gr.Examples``.
    additional_markdown : dict
        Page texts: ``page_title``, ``main_title``, ``details`` and, unless
        ``inverse_design`` is True, ``interpretation``.
    size : dict
        ``input_column_scale`` and ``input_column_min_width`` of the input column.
    osium_theme, css_styling :
        Theme and CSS passed to ``gr.Blocks``.
    predict_fn : callable
        Called with one positional argument per input component; must return one
        value per output component, followed by the interpretation plot when
        ``inverse_design`` is False.
    inverse_design : bool, optional
        When True, the interpretation plot row is omitted.

    Returns
    -------
    gr.Blocks
        The assembled (not yet launched) app.
    """
    with gr.Blocks(css=css_styling, title=additional_markdown["page_title"], theme=osium_theme) as demo:
        gr.Markdown(f"# <p style='text-align: center;'>{additional_markdown['main_title']}</p>")
        gr.Markdown(additional_markdown["details"])

        with gr.Row():
            clear_button = gr.Button("Clear")
            prediction_button = gr.Button("Predict", elem_id="submit")

        input_list = []
        output_list = []
        with gr.Row():
            # Input component section
            with gr.Column(scale=size["input_column_scale"], min_width=size["input_column_min_width"]):
                for section_dict in input_order:
                    gr.Markdown(f"### {section_dict['markdown']}")
                    for col_name in section_dict["keys"]:
                        input_list.append(add_gradio_component(input_mapping, col_name))
            # Output component section
            with gr.Column():
                with gr.Row():
                    for section_dict in output_order:
                        with gr.Column():
                            gr.Markdown(f"### {section_dict['markdown']}")
                            for col_name in section_dict["keys"]:
                                output_list.append(add_gradio_component(output_mapping, col_name))
                if not inverse_design:
                    # Currently one plot contains all the interpretation figures. It is
                    # appended last, so predict_fn must return the figure last.
                    with gr.Row():
                        with gr.Column():
                            with gr.Row():
                                gr.Markdown(f"### {additional_markdown['interpretation']}")
                            with gr.Row():
                                output_interpretation = gr.Plot(label="Interpretation", type="matplotlib")
                                output_list.append(output_interpretation)

        with gr.Row():
            gr.Examples([example_inputs], input_list)

        prediction_button.click(
            fn=predict_fn,
            inputs=input_list,
            outputs=output_list,
            show_progress=True,
        )
        # gradio calls the handler with one argument per input; with inputs=[] it is
        # called with none, so the handler must take no parameters (the former
        # ``lambda x:`` raised TypeError on every Clear click).
        clear_button.click(
            lambda: [gr.update(value=None)] * (len(input_list) + len(output_list)),
            [],
            input_list + output_list,
        )
    return demo