Spaces:
Sleeping
Sleeping
import gradio as gr | |
import pandas as pd | |
import os | |
import matplotlib.pyplot as plt | |
# from utils import unpickle_file, scale_numerical_w_missing | |
# from utils import unpickle_file | |
import plotly.express as px | |
from gradio_utils import load_theme | |
# from alloy_data_preprocessing import add_physics_features | |
# from inference_model_main import predict_all_results | |
import plotly.graph_objects as go | |
import yaml | |
import numpy as np | |
# def run_predictions( | |
# df, | |
# scaler_inputs_path, | |
# main_model_path, | |
# main_input_cols_order, | |
# intermediate_model_path, | |
# intermediate_results_columns, | |
# ): | |
# """ | |
# Scale the data and runs the predictions on the intermediate columns and the final properties | |
# """ | |
# scaler_inputs = unpickle_file(scaler_inputs_path) | |
# df_p = add_physics_features(df) | |
# df_scaled = scale_numerical_w_missing(df_p, scaler_inputs.feature_names_in_, scaler_inputs) | |
# y_pred, uncertainty = predict_all_results( | |
# df_scaled, | |
# main_model_path, | |
# main_input_cols_order, | |
# scaler_targets_main=None, | |
# intermediate_model_path=intermediate_model_path, | |
# intermediate_results_columns=intermediate_results_columns, | |
# return_uncertainty=True, | |
# uncertainty_type="weighted", | |
# ) | |
# return y_pred, uncertainty | |
# def create_domain_space(space_dict, inference_dict, df_path): | |
# """ | |
# Create the dataframe containing the pre-computed values for the uncertainty | |
# """ | |
# input_cols = ["%C", "%Co", "%Cr", "%V", "%Mo", "%W", "Temperature_C"] | |
# c = space_dict["%C"]["value"] | |
# co = space_dict["%Co"]["value"] | |
# cr = space_dict["%Cr"]["value"] | |
# v = space_dict["%V"]["value"] | |
# mo = space_dict["%Mo"]["value"] | |
# w = space_dict["%W"]["value"] | |
# temp = 538 | |
# space_list = [ | |
# [ic, ico, icr, iv, imo, iw, temp] | |
# for ic in np.arange( | |
# space_dict["%C"]["min"], space_dict["%C"]["max"] + space_dict["%C"]["step"], space_dict["%C"]["step"] | |
# ) | |
# for ico in np.arange( | |
# space_dict["%Co"]["min"], space_dict["%Co"]["max"] + space_dict["%Co"]["step"], space_dict["%Co"]["step"] | |
# ) | |
# for icr in np.arange( | |
# space_dict["%Cr"]["min"], space_dict["%Cr"]["max"] + space_dict["%Cr"]["step"], space_dict["%Cr"]["step"] | |
# ) | |
# for iv in np.arange( | |
# space_dict["%V"]["min"], space_dict["%V"]["max"] + space_dict["%V"]["step"], space_dict["%V"]["step"] | |
# ) | |
# for imo in np.arange( | |
# space_dict["%Mo"]["min"], space_dict["%Mo"]["max"] + space_dict["%Mo"]["step"], space_dict["%Mo"]["step"] | |
# ) | |
# for iw in np.arange( | |
# space_dict["%W"]["min"], space_dict["%W"]["max"] + space_dict["%W"]["step"], space_dict["%W"]["step"] | |
# ) | |
# ] | |
# df_synth = pd.DataFrame(space_list, columns=input_cols) | |
# print("Uncertainty space will be computed on:") | |
# print(df_synth.shape) | |
# model_path = inference_dict["final_prediction"]["model_path"] | |
# print("Model used:", model_path) | |
# scaler_inputs_intermediate = inference_dict["scaler_inputs_path"] | |
# intermediate_cols = [ | |
# "%C matrice", | |
# "%Co matrice", | |
# "%Cr matrice", | |
# "%V matrice", | |
# "%Mo matrice", | |
# "%W matrice", | |
# "M6C", | |
# "M23C6", | |
# "FCCA1#2", | |
# "M2C", | |
# "MC - SHP", | |
# "MC ETA", | |
# ] | |
# scaler_inputs_main = unpickle_file(inference_dict["final_prediction"]["scaler_inputs_path"]) | |
# intermediate_model_path_dict = inference_dict["multiple_model_path"] | |
# y_pred, uncertainty = run_predictions( | |
# df_synth, | |
# scaler_inputs_intermediate, | |
# model_path, | |
# scaler_inputs_main.feature_names_in_, | |
# intermediate_model_path_dict, | |
# intermediate_cols, | |
# ) | |
# df_synth_pred = df_synth.copy() | |
# df_synth_pred["y_pred"] = y_pred | |
# df_synth_pred["uncertainty_not_scaled"] = uncertainty | |
# min_uncertainty, max_uncertainty = ( | |
# df_synth_pred["uncertainty_not_scaled"].min(), | |
# df_synth_pred["uncertainty_not_scaled"].max(), | |
# ) | |
# df_synth_pred["uncertainty"] = (df_synth_pred["uncertainty_not_scaled"] - min_uncertainty) / ( | |
# max_uncertainty - min_uncertainty | |
# ) | |
# print("Domain space created") | |
# print("-----------------------------") | |
# print("Saving dataframe at", df_path) | |
# df_synth_pred.to_csv(df_path, sep=";", index=False) | |
# return df_synth_pred | |
def load_domain_space(df_path):
    """
    Read the pre-computed design-space table from a ';'-separated CSV file.

    As a quick sanity check, prints the min/max of the raw
    ``uncertainty_not_scaled`` column and the shape of the table.

    Args:
        df_path: path of the CSV file to read.

    Returns:
        The loaded DataFrame.
    """
    design_space = pd.read_csv(df_path, sep=";")
    raw_uncertainty = design_space["uncertainty_not_scaled"]
    separator = "---------------------------"
    print(separator)
    print("min max", raw_uncertainty.min(), raw_uncertainty.max())
    print("Design space dataframe", design_space.shape)
    print(separator)
    return design_space
def filter_dataframe(params_list, df):
    """
    Keep the rows of ``df`` that satisfy three exact (column == value) constraints.

    Args:
        params_list: flat sequence ``[col1, val1, col2, val2, col3, val3, ...]``;
            only the first six entries are read.
        df: DataFrame to filter.

    Returns:
        ``(filtered_df, [col1, col2, col3])``.
    """
    names = [params_list[pos] for pos in (0, 2, 4)]
    targets = [params_list[pos] for pos in (1, 3, 5)]
    mask = df[names[0]] == targets[0]
    for name, target_value in zip(names[1:], targets[1:]):
        mask &= df[name] == target_value
    return df[mask], names
def interpolate_space(df, col_name, value):
    """
    Slice the design space at ``col_name == value``, interpolating when needed.

    - Exact match: if ``value`` equals a pre-computed grid value of ``col_name``
      (up to float round-off), the matching rows of ``df`` are returned unchanged.
    - In range: the two grid values bracketing ``value`` are located. Every
      non-input column is linearly interpolated between them, separately for
      each combination of the other input columns.
    - Out of range: the slice is clamped to the nearest grid value.

    In the interpolated and clamped cases, ``col_name`` is set to ``value`` in
    the result.

    Args:
        df: design-space DataFrame containing all input columns.
        col_name: input column to slice on.
        value: value of ``col_name`` at which to slice.

    Returns:
        The sliced (and possibly interpolated) DataFrame.

    Raises:
        ValueError: if interpolation is needed and ``col_name`` is not one of
            the input columns.
    """
    input_cols = ["%C", "%Co", "%Cr", "%V", "%Mo", "%W", "Temperature_C"]

    grid_values = np.sort(df[col_name].dropna().unique())

    # Tolerant match so values produced by float arithmetic (e.g. np.arange
    # grids) are still recognised as pre-computed.
    exact = grid_values[np.isclose(grid_values, value, rtol=1e-9, atol=1e-12)]
    if exact.size:
        print("value in column", col_name, value)
        return df[df[col_name].isin(exact)]

    if col_name not in input_cols:
        raise ValueError(f"Cannot interpolate on {col_name!r}: not one of {input_cols}")
    agg_cols = [c for c in input_cols if c != col_name]

    if grid_values.size == 0:
        # Nothing to interpolate from.
        return df.iloc[0:0].copy()

    below = grid_values[grid_values < value]
    above = grid_values[grid_values > value]
    if below.size == 0 or above.size == 0:
        # Outside the pre-computed range: clamp to the nearest grid value.
        nearest = grid_values[np.argmin(np.abs(grid_values - value))]
        v_lo = v_hi = nearest
        w_hi = 1.0
    else:
        # Closest distinct grid values on each side. The grid repeats every
        # value across all other combinations, so we must work on distinct
        # values, not on the closest rows.
        v_lo, v_hi = below[-1], above[0]
        w_hi = (value - v_lo) / (v_hi - v_lo)
    print("Interpolating", col_name, value, "between", v_lo, "and", v_hi)

    df_tmp = df[df[col_name].isin([v_lo, v_hi])].copy()
    weights = np.where(df_tmp[col_name] == v_hi, w_hi, 1.0 - w_hi)
    value_cols = [c for c in df_tmp.columns if c not in agg_cols and c != col_name]

    # Weighted average per group: sum(w * v) / sum(w). Dividing by the weight
    # sum keeps the result correct when a combination exists on one side only.
    df_tmp[value_cols] = df_tmp[value_cols].mul(weights, axis=0)
    df_tmp["_weight"] = weights
    grouped = df_tmp.groupby(agg_cols)[value_cols + ["_weight"]].sum()
    df_out = grouped[value_cols].div(grouped["_weight"], axis=0).reset_index()
    df_out[col_name] = value

    print("==============")
    print("Value interpolated", col_name, value)
    print(df_out.shape)
    return df_out
def interpolate_all(params_list, df):
    """
    Apply ``interpolate_space`` successively for every (column, value) pair.

    Args:
        params_list: flat sequence ``[col_a, val_a, col_b, val_b, ...]``. Each
            pair is applied in order, on the result of the previous step.
        df: design-space DataFrame (not modified; a copy is used).

    Returns:
        ``(sliced_df, [col_a, col_b, ...])``.
    """
    print(df.shape)
    current = df.copy()
    used_columns = []
    for pos in range(0, len(params_list), 2):
        column, target_value = params_list[pos], params_list[pos + 1]
        current = interpolate_space(current, column, target_value)
        used_columns.append(column)
    print(current.shape)
    return current, used_columns
def _explored_space_cube(explored_domain_space, x_col, y_col, z_col):
    """Build a translucent Mesh3d cube spanning the explored bounds of the three plot axes."""

    def _scale(unit_coords, col):
        # Map unit-cube coordinates (0/1) onto [min, max] of the given column.
        lower = explored_domain_space[col]["min"]
        upper = explored_domain_space[col]["max"]
        return (np.array(unit_coords) * (upper - lower) + lower).tolist()

    return go.Mesh3d(
        # 8 vertices of a cube
        x=_scale([0, 0, 1, 1, 0, 0, 1, 1], x_col),
        y=_scale([0, 1, 1, 0, 0, 1, 1, 0], y_col),
        z=_scale([0, 0, 0, 0, 1, 1, 1, 1], z_col),
        # Keep these values (i, j, k) to get a cube: 12 triangles over the 8 vertices
        i=[7, 0, 0, 0, 4, 4, 6, 6, 4, 0, 3, 2],
        j=[3, 4, 1, 2, 5, 6, 5, 2, 0, 1, 6, 3],
        k=[0, 7, 2, 3, 6, 7, 1, 1, 5, 5, 7, 6],
        opacity=0.3,
        color="turquoise",
        flatshading=True,
        name="Training space",
        hovertemplate=x_col + ": %{x:.2f}<br>" + y_col + ": %{y:.2f}<br>" + z_col + ": %{z:.2f}",
    )


def make_domain_plot(df_synth_pred, explored_domain_space, x_list, target="y_pred"):
    """
    Create a 3D plot of the design space sliced at three fixed inputs, with the
    explored (training) space overlaid as a cube.

    Args:
        df_synth_pred: design-space DataFrame. It must contain the input
            columns, ``target`` and ``"uncertainty"`` (used in the hover data).
        explored_domain_space: mapping column -> {"min": ..., "max": ...} with
            the bounds of the training space.
        x_list: flat sequence ``[col1, val1, col2, val2, col3, val3, ...]``;
            only the first six entries are used.
        target: column used for the colour scale.

    Returns:
        A plotly figure. The training-space cube is added only when all three
        fixed values lie within their explored bounds.
    """
    input_cols = ["%C", "%Co", "%Cr", "%V", "%Mo", "%W", "Temperature_C"]

    # Colour range is taken over the whole design space, not the slice, so the
    # colour scale does not depend on the current slider positions.
    color_min = df_synth_pred[target].min()
    color_max = df_synth_pred[target].max()

    df_slice, filter_cols = interpolate_all(x_list[:6], df_synth_pred)

    # Axes: the first input columns that are neither fixed nor the temperature.
    # Choosing from input_cols (not the frame's column order) guarantees output
    # columns such as y_pred / uncertainty never end up as an axis.
    excluded = set(filter_cols) | {"Temperature_C"}
    axes = [c for c in input_cols if c in df_synth_pred.columns and c not in excluded]
    x_col, y_col, z_col = axes[0], axes[1], axes[2]

    fig = px.scatter_3d(
        df_slice,
        x=x_col,
        y=y_col,
        z=z_col,
        color=target,
        range_color=[color_min, color_max],
        hover_data={target: ":.3f", "uncertainty": ":.3f"},
    )

    # Only draw the explored domain when the slice itself lies inside it.
    for pos in (0, 2, 4):
        bounds = explored_domain_space[x_list[pos]]
        fixed_value = x_list[pos + 1]
        if fixed_value < bounds["min"] or fixed_value > bounds["max"]:
            return fig

    fig.add_trace(_explored_space_cube(explored_domain_space, x_col, y_col, z_col))
    return fig
def create_plot(df_synth_pred, explored_space_dict, target="y_pred"):
    """
    Wrapper creating the callbacks that generate the plotly plots.

    Args:
        df_synth_pred: pre-computed design-space DataFrame.
        explored_space_dict: per-column {"min", "max"} bounds of the training
            space, forwarded to make_domain_plot.
        target: column used to colour the points. It defaults to "y_pred",
            matching make_domain_plot, so two-argument calls work.

    Returns:
        ``(plot_fn, update_fn)``. Both take the Gradio inputs as positional
        arguments. ``plot_fn`` returns a plotly figure; ``update_fn`` wraps it
        in ``gr.update(value=...)``.

    The positional inputs are interpreted by their count:
        * 6: ``[col1, val1, col2, val2, col3, val3]`` -> domain plot only.
        * 9: the 6 above + (optimization results table, c_min, c_max).
        * otherwise: the 6 above + one value per input column (prediction input).
    """
    input_cols = ["%C", "%Co", "%Cr", "%V", "%Mo", "%W", "Temperature_C"]

    def _plot_axes(filter_cols):
        # Same axis choice as the domain plot: first input columns that are
        # neither fixed nor the temperature (never output columns).
        excluded = set(filter_cols) | {"Temperature_C"}
        axes = [c for c in input_cols if c in df_synth_pred.columns and c not in excluded]
        return axes[0], axes[1], axes[2]

    def plot_figure(x):
        fig = make_domain_plot(df_synth_pred, explored_space_dict, x[:6], target)
        if len(x) == 6:
            return fig

        if len(x) == 9:
            # Case of function call from the inverse design module
            print("Running optimization visualization")
            df = x[6]
            # If empty table (when first loading the interface)
            if df.shape[1] == 3:
                return fig
            # Duplicate every result at c_min and c_max so it can show up in the domain space
            c_min, c_max = x[7], x[8]
            df_full = pd.concat([df.assign(**{"%C": c_min}), df.assign(**{"%C": c_max})])
            df_filtered, filter_cols = filter_dataframe(x[:6], df_full)
            trace_name = "Optimization results space"
        else:
            # Case of function call from the property prediction module: x[6:]
            # holds one value per input column.
            # For now this only supports the alloy space explored with the August 2023 pilot
            df = pd.DataFrame([x[6:]], columns=input_cols)
            df_filtered, filter_cols = filter_dataframe(x[:6], df)
            trace_name = "Prediction input space"

        # If no data points match the selected space
        if df_filtered.shape[0] == 0:
            print("No data points matching the selected domain space")
            return fig

        x_col, y_col, z_col = _plot_axes(filter_cols)
        trace = go.Scatter3d(
            x=df_filtered[x_col],
            y=df_filtered[y_col],
            z=df_filtered[z_col],
            mode="markers",
            name=trace_name,
            hovertemplate=x_col + ": %{x:.2f}<br>" + y_col + ": %{y:.2f}<br>" + z_col + ": %{z:.2f}",
        )
        fig.add_trace(trace)
        return fig

    def update_figure(x):
        return gr.update(value=plot_figure(x))

    return lambda *x: plot_figure(x), lambda *x: update_figure(x)
def update_plot(x):
    """
    Rebuild the design space from ``x`` and push the result to a Gradio component.

    Args:
        x: positional arguments forwarded to ``create_domain_space``.

    Returns:
        ``gr.update(value=...)`` with whatever ``create_domain_space`` returns.
        NOTE(review): create_domain_space appears to return a DataFrame rather
        than a figure — confirm before wiring this to a gr.Plot.

    Raises:
        NotImplementedError: if ``create_domain_space`` is not defined in this
            module (its implementation is commented out at the time of writing).
    """
    try:
        builder = create_domain_space  # noqa: F821 - may be absent on purpose
    except NameError as err:
        raise NotImplementedError(
            "update_plot requires create_domain_space, which is not defined in this module"
        ) from err
    fig = builder(*x)
    return gr.update(value=fig)
def update_dropdown(*x):
    """
    Recompute the choices of the three "Fix element" dropdowns.

    Dropdown ``i`` may select any input column except the values currently
    selected in the *other* dropdowns, so a column cannot be fixed twice.
    Building the lists by exclusion (instead of ``list.remove``) avoids a
    ValueError when two dropdowns momentarily share a value or one is cleared
    (None).

    Args:
        *x: current values of the dropdowns, in order.

    Returns:
        A tuple of three ``gr.update(choices=...)``, one per dropdown.
    """
    input_cols = ["%C", "%Co", "%Cr", "%V", "%Mo", "%W", "Temperature_C"]
    updates = []
    for idx in range(3):
        taken_by_others = {val for j, val in enumerate(x) if j != idx}
        updates.append(gr.update(choices=[c for c in input_cols if c not in taken_by_others]))
    return tuple(updates)
def on_select(evt: gr.SelectData):  # SelectData is a subclass of EventData
    """Debug handler: print which value/index was selected and on which component."""
    print("_________________________________")
    message = "You selected {} at {} from {}".format(evt.value, evt.index, evt.target)
    print(message)
    return None
def create_slicer_update(space_dict):
    """
    Build the callback that reconfigures a slider when its dropdown changes.

    Args:
        space_dict: mapping column -> {"value", "min", "max", "step_display"}.

    Returns:
        A one-argument callable. Given a column name, it returns a
        ``gr.update`` setting the slider label, value, bounds and step from
        ``space_dict``.
    """

    def update_slicer(x):
        settings = space_dict[x]
        return gr.update(
            label=x,
            value=settings["value"],
            minimum=settings["min"],
            maximum=settings["max"],
            step=settings["step_display"],
        )

    return update_slicer
def create_gradio(plot_fn, update_plot_fn, update_slider_fn, space_dict=None):
    """
    To test the domain space exploration locally: build the Gradio app.

    Args:
        plot_fn: callback rendering the plot from the six (dropdown, slider)
            values; bound to the button.
        update_plot_fn: callback returning a ``gr.update`` for the plot; bound
            to every slider change.
        update_slider_fn: callback reconfiguring a slider for the column newly
            chosen in its dropdown.
        space_dict: per-column slider settings ("min", "max", "value",
            "step_display"). When omitted, the module-level ``space_dict``
            (bound in this file's ``__main__`` section) is used, so existing
            three-argument calls keep working.

    Returns:
        The (not yet launched) ``gr.Blocks`` app.

    Raises:
        ValueError: if no ``space_dict`` is given and none exists at module level.
    """
    if space_dict is None:
        space_dict = globals().get("space_dict")
        if space_dict is None:
            raise ValueError("create_gradio needs a space_dict (argument or module-level variable)")

    input_cols = ["%C", "%Co", "%Cr", "%V", "%Mo", "%W", "Temperature_C"]
    default_elements = ["%Cr", "%V", "%Mo"]

    with gr.Blocks() as demo:
        gr.Markdown("# <p style='text-align: center;'>Adapt your AI models</p>")
        gr.Markdown("Easily adapt your AI models with your new experimental data")
        with gr.Row():
            train_button = gr.Button()
        with gr.Row():
            with gr.Column():
                gr.Markdown("### Your input files")
                dropdowns, sliders = [], []
                for idx, elem in enumerate(default_elements, start=1):
                    # Each dropdown excludes the elements fixed by the other two.
                    others = [e for e in default_elements if e != elem]
                    with gr.Row():
                        dropdowns.append(
                            gr.Dropdown(
                                label=f"Fix element {idx}",
                                choices=[c for c in input_cols if c not in others],
                                value=elem,
                            )
                        )
                        sliders.append(
                            gr.Slider(
                                label=elem,
                                minimum=space_dict[elem]["min"],
                                maximum=space_dict[elem]["max"],
                                value=space_dict[elem]["value"],
                                step=space_dict[elem]["step_display"],
                            )
                        )
            with gr.Column():
                gr.Markdown("### Your model adaptation")
                output_plot = gr.Plot(type="plotly")

        # Interleaved [dropdown_1, slider_1, dropdown_2, slider_2, dropdown_3, slider_3]
        plot_inputs = [component for pair in zip(dropdowns, sliders) for component in pair]

        train_button.click(
            fn=plot_fn,
            inputs=plot_inputs,
            outputs=[output_plot],
            show_progress=True,
        )
        # No ``every=`` here: it would re-run this (expensive) callback
        # periodically for as long as the page stays open.
        for slider in sliders:
            slider.change(
                fn=update_plot_fn,
                inputs=plot_inputs,
                outputs=[output_plot],
                show_progress=True,
                queue=True,
            )
        # Update the choices in the dropdowns based on the elements selected
        for dropdown in dropdowns:
            dropdown.change(
                fn=update_dropdown,
                inputs=dropdowns,
                outputs=dropdowns,
                show_progress=True,
            )
        # Update the slider name/range based on the choice of the dropdown
        for dropdown, slider in zip(dropdowns, sliders):
            dropdown.change(fn=update_slider_fn, inputs=[dropdown], outputs=[slider])
    return demo
if __name__ == "__main__":
    # Local entry point: load the config, the pre-computed design space and launch the UI.
    with open("./conf_test_uncertainty.yaml", "rb") as file:
        conf = yaml.safe_load(file)
    # NOTE: create_gradio reads this module-level name for the slider settings.
    space_dict = conf["domain_space"]["uncertainty_space_dict"]
    explored_dict = conf["domain_space"]["explored_space_dict"]
    # df_synth = create_domain_space(space_dict, conf["inference"], df_path=conf["domain_space"]["design_space_path"])
    df_synth = load_domain_space(conf["domain_space"]["design_space_path"])
    # Colour the design space by the model prediction; target passed explicitly.
    plot_fn, update_plot_fn = create_plot(df_synth, explored_dict, "y_pred")
    update_slicer_fn = create_slicer_update(space_dict)
    demo = create_gradio(plot_fn, update_plot_fn, update_slicer_fn)
    # launch(enable_queue=True) is deprecated; enable the queue via Blocks.queue().
    demo.queue().launch()