|
|
|
|
|
import gradio as gr |
|
|
|
|
|
|
|
import pandas as pd |
|
import yaml |
|
import shap |
|
import matplotlib.pyplot as plt |
|
import numpy as np |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from utils import predict, unpickle_file, scale_numerical, encode_categorical |
|
|
|
|
|
def call_predict(inference_dict, cols_order):
    """
    Build the prediction callback used by the Gradio interface.

    Loads the fitted input scaler, target scaler and categorical encoder from
    the paths stored in ``inference_dict``, preprocesses the reference
    training set once, and returns a callable that turns one row of raw
    inputs into rescaled predictions plus a SHAP summary figure.

    Args:
        inference_dict: mapping providing the keys "scaler_inputs_path",
            "scaler_targets_path", "encoder_path" and "model_path".
        cols_order: input column names, in the order in which the values are
            passed to the returned callable.

    Returns:
        A callable taking one positional value per entry of ``cols_order``
        and returning a 5-tuple:
        (rounded prediction column 0, random value, rounded prediction
        column 1, random value, matplotlib figure).
    """
    scaler_inputs = unpickle_file(inference_dict["scaler_inputs_path"])
    scaler_targets = unpickle_file(inference_dict["scaler_targets_path"])
    encoder = unpickle_file(inference_dict["encoder_path"])

    categorical_columns = ["infill_pattern", "material"]
    target_columns = ["roughness", "tension_strength", "elongation"]
    numerical_columns = [c for c in cols_order if c not in categorical_columns]

    # The training set is preprocessed once here and reused by every
    # prediction as the ``df_train`` argument of ``predict``.
    df_train = pd.read_csv("dataset_preprocessed.csv", sep=";")
    df_train.drop(columns=target_columns, inplace=True)
    df_train = scale_numerical(df_train, numerical_columns, scaler=scaler_inputs, fit=False)
    df_train = encode_categorical(df_train, categorical_columns, encoder=encoder, fit=False)

    # Figure returned by the previous call; it is closed at the start of the
    # next call so that pyplot's figure registry does not grow without bound.
    previous_figure = None

    def predict_from_list(x_list):
        nonlocal previous_figure

        df = pd.DataFrame([x_list], columns=cols_order)

        df_preprocessed = scale_numerical(df, numerical_columns, scaler=scaler_inputs, fit=False)
        df_preprocessed = encode_categorical(df_preprocessed, categorical_columns, encoder=encoder, fit=False)

        y_pred, _, shap_values = predict(inference_dict["model_path"], df_preprocessed, df_train=df_train)

        y_pred_rescaled = scaler_targets.inverse_transform(y_pred)

        # Release the figure produced by the previous request: it has already
        # been handed back to the caller, but pyplot keeps every figure alive
        # until it is explicitly closed.
        if previous_figure is not None:
            plt.close(previous_figure)

        plt.figure(figsize=(15, 15))
        # One summary plot per entry of shap_values; only the first two
        # entries are displayed (presumably one entry per target - confirm
        # against utils.predict).
        plt.subplot(1, 2, 1)
        shap.summary_plot(
            shap_values[0],
            df_preprocessed,
            show=False,
            feature_names=df_preprocessed.columns,
            plot_size=(15, 15),
        )
        plt.subplot(1, 2, 2)
        shap.summary_plot(
            shap_values[1],
            df_preprocessed,
            show=False,
            feature_names=df_preprocessed.columns,
            plot_size=(15, 15),
        )

        plt.tight_layout()
        plt.subplots_adjust(wspace=2.0)
        fig = plt.gcf()
        previous_figure = fig

        # NOTE(review): the 2nd and 4th returned values are random draws from
        # [2, 5), not model outputs. They look like placeholders (possibly for
        # an uncertainty display) - confirm and replace with real values.
        return (
            np.round(y_pred_rescaled[0][0], 1),
            np.round(np.random.uniform(2, 5), 1),
            np.round(y_pred_rescaled[0][1], 1),
            np.round(np.random.uniform(2, 5), 1),
            fig,
        )

    # Gradio passes each input component's value as a separate positional
    # argument; gather them into a single row.
    return lambda *x: predict_from_list(x)
|
|
|
|
|
def initialize_config(config_name):
    """
    Load the YAML configuration and define the color theme.

    Args:
        config_name: path of the YAML configuration file. The file must
            provide "input_order" (a list of sections, each with a "keys"
            list) and "input_mapping" (one entry per input key, each with an
            "example" value).

    Returns:
        A 5-tuple ``(config, cols_order, osium_theme, css_styling,
        example_inputs)`` where ``config`` is the parsed YAML content,
        ``cols_order`` is the flattened list of input keys in section order,
        ``osium_theme`` is the Gradio theme, ``css_styling`` is the custom CSS
        string and ``example_inputs`` holds the example value of every input,
        in ``cols_order`` order.
    """
    # Custom palette used as the neutral hue of the theme.
    osium_theme_colors = gr.themes.Color(
        c50="#e4f3fa",
        c100="#e4f3fa",
        c200="#a1c6db",
        c300="#FFFFFF",
        c400="#e4f3fa",
        c500="#0c1538",
        c600="#a1c6db",
        c700="#475383",
        c800="#0c1538",
        c900="#a1c6db",
        c950="#0c1538",
    )

    osium_theme = gr.themes.Default(primary_hue="cyan", secondary_hue="cyan", neutral_hue=osium_theme_colors)

    css_styling = (
        "#submit {background: #1eccd8}\n"
        "#submit:hover {background: #a2f1f6}\n"
        ".output-image, .input-image, .image-preview {height: 250px !important}\n"
        ".output-plot {height: 250px !important}\n"
        "#interpretation {height: 250px !important}"
    )

    # Decode explicitly as UTF-8 so that parsing does not depend on the
    # platform's default locale encoding.
    with open(config_name, "r", encoding="utf-8") as file:
        config = yaml.safe_load(file)

    # Flatten the per-section key lists while preserving section order.
    cols_order = [col_name for section_dict in config["input_order"] for col_name in section_dict["keys"]]

    example_inputs = [config["input_mapping"][col_name]["example"] for col_name in cols_order]

    return config, cols_order, osium_theme, css_styling, example_inputs
|
|
|
|
|
def add_gradio_component(config_dict, component_key):
    """
    Instantiate the Gradio widget described by ``config_dict[component_key]``.

    The entry's "comp_type" selects the widget class. The "label" and
    "cat_values" entries are read only for the widget types that use them.
    An unsupported type is reported on stdout and yields ``None``.
    """
    spec = config_dict[component_key]
    kind = spec["comp_type"]

    if kind == "Text":
        return gr.Text(label=spec["label"], placeholder=spec["label"])

    if kind == "Number":
        return gr.Number(label=spec["label"], precision=3)

    if kind == "Dropdown":
        return gr.Dropdown(label=spec["label"], choices=spec["cat_values"])

    if kind == "Image":
        return gr.Image(elem_classes="image-preview")

    if kind == "CheckboxGroup":
        return gr.CheckboxGroup(label=spec["label"], choices=spec["cat_values"])

    if kind == "Plot":
        return gr.Plot(label=spec["label"], type="matplotlib")

    if kind == "Dataframe":
        return gr.Dataframe(wrap=True, type="pandas")

    # No branch matched: warn and let the caller deal with the missing widget.
    print(f"Found component type {kind} for {component_key}, which is not supported")
    return None
|
|
|
|
|
def create_gradio_interface(
    input_order,
    input_mapping,
    output_order,
    output_mapping,
    example_inputs,
    additional_markdown,
    size,
    osium_theme,
    css_styling,
    predict_fn,
    inverse_design=False,
):
    """
    Create the Gradio visual interface from the configuration file.

    Args:
        input_order: list of sections, each a dict with a "markdown" title and
            a "keys" list of input names.
        input_mapping: per-input component description, forwarded to
            ``add_gradio_component``.
        output_order: same structure as ``input_order``, for the outputs.
        output_mapping: per-output component description, forwarded to
            ``add_gradio_component``.
        example_inputs: one example value per input component, in the order
            in which the input components are created.
        additional_markdown: dict providing "page_title", "main_title",
            "details" and, unless ``inverse_design`` is set, "interpretation".
        size: dict providing "input_column_scale" and
            "input_column_min_width" for the input column.
        osium_theme: Gradio theme applied to the page.
        css_styling: custom CSS string applied to the page.
        predict_fn: callback bound to the "Predict" button. It receives the
            input component values and must return one value per output
            component.
        inverse_design: when true, the interpretation plot is not added.

    Returns:
        The assembled ``gr.Blocks`` application.
    """
    with gr.Blocks(css=css_styling, title=additional_markdown["page_title"], theme=osium_theme) as demo:
        gr.Markdown(f"# <p style='text-align: center;'>{additional_markdown['main_title']}</p>")
        gr.Markdown(additional_markdown["details"])

        with gr.Row():
            clear_button = gr.Button("Clear")
            prediction_button = gr.Button("Predict", elem_id="submit")

        input_list = []
        output_list = []
        with gr.Row():
            # Left column: one titled group of input widgets per section.
            with gr.Column(scale=size["input_column_scale"], min_width=size["input_column_min_width"]):
                for section_dict in input_order:
                    gr.Markdown(f"### {section_dict['markdown']}")
                    for col_name in section_dict["keys"]:
                        input_component = add_gradio_component(input_mapping, col_name)
                        input_list.append(input_component)

            # Right column: output sections side by side, then the
            # interpretation plot underneath.
            with gr.Column():
                with gr.Row():
                    for section_dict in output_order:
                        with gr.Column():
                            gr.Markdown(f"### {section_dict['markdown']}")
                            for col_name in section_dict["keys"]:
                                output_component = add_gradio_component(output_mapping, col_name)
                                output_list.append(output_component)
                if not inverse_design:
                    with gr.Row():
                        with gr.Column():
                            with gr.Row():
                                gr.Markdown(f"### {additional_markdown['interpretation']}")
                            with gr.Row():
                                output_interpretation = gr.Plot(label="Interpretation", type="matplotlib")
                                # Registered last, so it receives the last
                                # value returned by predict_fn.
                                output_list.append(output_interpretation)

        with gr.Row():
            gr.Examples([example_inputs], input_list)

        prediction_button.click(
            fn=predict_fn,
            inputs=input_list,
            outputs=output_list,
            show_progress=True,
        )
        # The handler is registered with no inputs, so Gradio calls it with
        # zero arguments; a handler requiring one positional argument would
        # raise TypeError on every click. Accept (and ignore) any arguments.
        clear_button.click(
            lambda *_: [gr.update(value=None)] * (len(input_list) + len(output_list)),
            [],
            input_list + output_list,
        )
    return demo
|
|