Spaces:
Sleeping
Sleeping
Create domain_space.py
Browse files- domain_space.py +504 -0
domain_space.py
ADDED
@@ -0,0 +1,504 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import gradio as gr
|
2 |
+
import pandas as pd
|
3 |
+
import os
|
4 |
+
import matplotlib.pyplot as plt
|
5 |
+
|
6 |
+
from utils import unpickle_file, scale_numerical_w_missing
|
7 |
+
import plotly.express as px
|
8 |
+
from gradio_utils import load_theme
|
9 |
+
from alloy_data_preprocessing import add_physics_features
|
10 |
+
from inference_model_main import predict_all_results
|
11 |
+
import plotly.graph_objects as go
|
12 |
+
import yaml
|
13 |
+
import numpy as np
|
14 |
+
|
15 |
+
|
16 |
+
def run_predictions(
|
17 |
+
df,
|
18 |
+
scaler_inputs_path,
|
19 |
+
main_model_path,
|
20 |
+
main_input_cols_order,
|
21 |
+
intermediate_model_path,
|
22 |
+
intermediate_results_columns,
|
23 |
+
):
|
24 |
+
"""
|
25 |
+
Scale the data and runs the predictions on the intermediate columns and the final properties
|
26 |
+
"""
|
27 |
+
scaler_inputs = unpickle_file(scaler_inputs_path)
|
28 |
+
df_p = add_physics_features(df)
|
29 |
+
df_scaled = scale_numerical_w_missing(df_p, scaler_inputs.feature_names_in_, scaler_inputs)
|
30 |
+
|
31 |
+
y_pred, uncertainty = predict_all_results(
|
32 |
+
df_scaled,
|
33 |
+
main_model_path,
|
34 |
+
main_input_cols_order,
|
35 |
+
scaler_targets_main=None,
|
36 |
+
intermediate_model_path=intermediate_model_path,
|
37 |
+
intermediate_results_columns=intermediate_results_columns,
|
38 |
+
return_uncertainty=True,
|
39 |
+
uncertainty_type="weighted",
|
40 |
+
)
|
41 |
+
|
42 |
+
return y_pred, uncertainty
|
43 |
+
|
44 |
+
|
45 |
+
def create_domain_space(space_dict, inference_dict, df_path):
|
46 |
+
"""
|
47 |
+
Create the dataframe containing the pre-computed values for the uncertainty
|
48 |
+
"""
|
49 |
+
input_cols = ["%C", "%Co", "%Cr", "%V", "%Mo", "%W", "Temperature_C"]
|
50 |
+
|
51 |
+
c = space_dict["%C"]["value"]
|
52 |
+
co = space_dict["%Co"]["value"]
|
53 |
+
cr = space_dict["%Cr"]["value"]
|
54 |
+
v = space_dict["%V"]["value"]
|
55 |
+
mo = space_dict["%Mo"]["value"]
|
56 |
+
w = space_dict["%W"]["value"]
|
57 |
+
temp = 538
|
58 |
+
space_list = [
|
59 |
+
[ic, ico, icr, iv, imo, iw, temp]
|
60 |
+
for ic in np.arange(
|
61 |
+
space_dict["%C"]["min"], space_dict["%C"]["max"] + space_dict["%C"]["step"], space_dict["%C"]["step"]
|
62 |
+
)
|
63 |
+
for ico in np.arange(
|
64 |
+
space_dict["%Co"]["min"], space_dict["%Co"]["max"] + space_dict["%Co"]["step"], space_dict["%Co"]["step"]
|
65 |
+
)
|
66 |
+
for icr in np.arange(
|
67 |
+
space_dict["%Cr"]["min"], space_dict["%Cr"]["max"] + space_dict["%Cr"]["step"], space_dict["%Cr"]["step"]
|
68 |
+
)
|
69 |
+
for iv in np.arange(
|
70 |
+
space_dict["%V"]["min"], space_dict["%V"]["max"] + space_dict["%V"]["step"], space_dict["%V"]["step"]
|
71 |
+
)
|
72 |
+
for imo in np.arange(
|
73 |
+
space_dict["%Mo"]["min"], space_dict["%Mo"]["max"] + space_dict["%Mo"]["step"], space_dict["%Mo"]["step"]
|
74 |
+
)
|
75 |
+
for iw in np.arange(
|
76 |
+
space_dict["%W"]["min"], space_dict["%W"]["max"] + space_dict["%W"]["step"], space_dict["%W"]["step"]
|
77 |
+
)
|
78 |
+
]
|
79 |
+
|
80 |
+
df_synth = pd.DataFrame(space_list, columns=input_cols)
|
81 |
+
|
82 |
+
print("Uncertainty space will be computed on:")
|
83 |
+
print(df_synth.shape)
|
84 |
+
|
85 |
+
model_path = inference_dict["final_prediction"]["model_path"]
|
86 |
+
print("Model used:", model_path)
|
87 |
+
scaler_inputs_intermediate = inference_dict["scaler_inputs_path"]
|
88 |
+
intermediate_cols = [
|
89 |
+
"%C matrice",
|
90 |
+
"%Co matrice",
|
91 |
+
"%Cr matrice",
|
92 |
+
"%V matrice",
|
93 |
+
"%Mo matrice",
|
94 |
+
"%W matrice",
|
95 |
+
"M6C",
|
96 |
+
"M23C6",
|
97 |
+
"FCCA1#2",
|
98 |
+
"M2C",
|
99 |
+
"MC - SHP",
|
100 |
+
"MC ETA",
|
101 |
+
]
|
102 |
+
scaler_inputs_main = unpickle_file(inference_dict["final_prediction"]["scaler_inputs_path"])
|
103 |
+
intermediate_model_path_dict = inference_dict["multiple_model_path"]
|
104 |
+
|
105 |
+
y_pred, uncertainty = run_predictions(
|
106 |
+
df_synth,
|
107 |
+
scaler_inputs_intermediate,
|
108 |
+
model_path,
|
109 |
+
scaler_inputs_main.feature_names_in_,
|
110 |
+
intermediate_model_path_dict,
|
111 |
+
intermediate_cols,
|
112 |
+
)
|
113 |
+
df_synth_pred = df_synth.copy()
|
114 |
+
df_synth_pred["y_pred"] = y_pred
|
115 |
+
df_synth_pred["uncertainty_not_scaled"] = uncertainty
|
116 |
+
min_uncertainty, max_uncertainty = (
|
117 |
+
df_synth_pred["uncertainty_not_scaled"].min(),
|
118 |
+
df_synth_pred["uncertainty_not_scaled"].max(),
|
119 |
+
)
|
120 |
+
df_synth_pred["uncertainty"] = (df_synth_pred["uncertainty_not_scaled"] - min_uncertainty) / (
|
121 |
+
max_uncertainty - min_uncertainty
|
122 |
+
)
|
123 |
+
print("Domain space created")
|
124 |
+
|
125 |
+
print("-----------------------------")
|
126 |
+
print("Saving dataframe at", df_path)
|
127 |
+
df_synth_pred.to_csv(df_path, sep=";", index=False)
|
128 |
+
return df_synth_pred
|
129 |
+
|
130 |
+
|
131 |
+
def load_domain_space(df_path):
|
132 |
+
df_synth_pred = pd.read_csv(df_path, sep=";")
|
133 |
+
print("---------------------------")
|
134 |
+
print("min max", df_synth_pred["uncertainty_not_scaled"].min(), df_synth_pred["uncertainty_not_scaled"].max())
|
135 |
+
print("Design space dataframe", df_synth_pred.shape)
|
136 |
+
print("---------------------------")
|
137 |
+
|
138 |
+
return df_synth_pred
|
139 |
+
|
140 |
+
|
141 |
+
def filter_dataframe(params_list, df):
|
142 |
+
col1_name = params_list[0]
|
143 |
+
col1_value = params_list[1]
|
144 |
+
col2_name = params_list[2]
|
145 |
+
col2_value = params_list[3]
|
146 |
+
col3_name = params_list[4]
|
147 |
+
col3_value = params_list[5]
|
148 |
+
|
149 |
+
df_filtered = df[(df[col1_name] == col1_value) & (df[col2_name] == col2_value) & (df[col3_name] == col3_value)]
|
150 |
+
|
151 |
+
return df_filtered, [col1_name, col2_name, col3_name]
|
152 |
+
|
153 |
+
|
154 |
+
def interpolate_space(df, col_name, value):
|
155 |
+
"""
|
156 |
+
Interpolate the uncertainty space for values within the range but not direcly pre-computed
|
157 |
+
"""
|
158 |
+
# No need to interpolate, uncertainty for this value is already pre-computed
|
159 |
+
if value in list(df[col_name]):
|
160 |
+
print("value in column", col_name, value)
|
161 |
+
return df[df[col_name] == value]
|
162 |
+
df_interpolated = df.copy()
|
163 |
+
# Find the closest values in the dataframe to the pass value
|
164 |
+
k_closest = 2
|
165 |
+
df_interpolated["distance"] = np.abs(df[col_name] - value)
|
166 |
+
print("Looking for closest values")
|
167 |
+
values_closest = list(
|
168 |
+
df_interpolated.sort_values(by=["distance"], ascending=True)[col_name].iloc[0:k_closest].values
|
169 |
+
)
|
170 |
+
|
171 |
+
input_cols = ["%C", "%Co", "%Cr", "%V", "%Mo", "%W", "Temperature_C"]
|
172 |
+
agg_cols = input_cols.copy()
|
173 |
+
agg_cols.remove(col_name)
|
174 |
+
print(agg_cols)
|
175 |
+
df_tmp = df[df[col_name].isin(values_closest)]
|
176 |
+
df_tmp = df_tmp.groupby(agg_cols).mean().reset_index().drop(columns=col_name)
|
177 |
+
df_tmp[col_name] = value
|
178 |
+
print("==============")
|
179 |
+
print("Value interpolated", col_name, value)
|
180 |
+
print(df_tmp.shape)
|
181 |
+
return df_tmp
|
182 |
+
|
183 |
+
|
184 |
+
def interpolate_all(params_list, df):
|
185 |
+
print(df.shape)
|
186 |
+
df_filtered = df.copy()
|
187 |
+
filter_cols = []
|
188 |
+
for i in range(0, len(params_list), 2):
|
189 |
+
df_filtered = interpolate_space(df_filtered, params_list[i], params_list[i + 1])
|
190 |
+
filter_cols.append(params_list[i])
|
191 |
+
print(df_filtered.shape)
|
192 |
+
return df_filtered, filter_cols
|
193 |
+
|
194 |
+
|
195 |
+
def make_domain_plot(df_synth_pred, explored_domain_space, x_list):
|
196 |
+
"""
|
197 |
+
Create a plot with the uncertainty space and the training space
|
198 |
+
"""
|
199 |
+
uncertainty_min = df_synth_pred["uncertainty"].min()
|
200 |
+
uncertainty_max = df_synth_pred["uncertainty"].max()
|
201 |
+
|
202 |
+
# df_synth_pred2, filter_cols = filter_dataframe(x_list[:6], df_synth_pred)
|
203 |
+
df_synth_pred2, filter_cols = interpolate_all(x_list[:6], df_synth_pred)
|
204 |
+
|
205 |
+
cols_for_plot = [c for c in df_synth_pred.columns if c not in filter_cols + ["Temperature_C"]]
|
206 |
+
x_col, y_col, z_col = cols_for_plot[0], cols_for_plot[1], cols_for_plot[2]
|
207 |
+
fig = px.scatter_3d(
|
208 |
+
df_synth_pred2,
|
209 |
+
x=x_col,
|
210 |
+
y=y_col,
|
211 |
+
z=z_col,
|
212 |
+
color="uncertainty",
|
213 |
+
range_color=[uncertainty_min, uncertainty_max],
|
214 |
+
hover_data={"uncertainty": ":.3f"},
|
215 |
+
)
|
216 |
+
|
217 |
+
# Filter domain space
|
218 |
+
for i in [0, 2, 4]:
|
219 |
+
if (x_list[i + 1] < explored_domain_space[x_list[i]]["min"]) or (
|
220 |
+
x_list[i + 1] > explored_domain_space[x_list[i]]["max"]
|
221 |
+
):
|
222 |
+
return fig
|
223 |
+
|
224 |
+
# Add explored domain space
|
225 |
+
x_cube = (
|
226 |
+
np.array([0, 0, 1, 1, 0, 0, 1, 1]) * (explored_domain_space[x_col]["max"] - explored_domain_space[x_col]["min"])
|
227 |
+
+ explored_domain_space[x_col]["min"]
|
228 |
+
)
|
229 |
+
y_cube = (
|
230 |
+
np.array([0, 1, 1, 0, 0, 1, 1, 0]) * (explored_domain_space[y_col]["max"] - explored_domain_space[y_col]["min"])
|
231 |
+
+ explored_domain_space[y_col]["min"]
|
232 |
+
)
|
233 |
+
z_cube = (
|
234 |
+
np.array([0, 0, 0, 0, 1, 1, 1, 1]) * (explored_domain_space[z_col]["max"] - explored_domain_space[z_col]["min"])
|
235 |
+
+ explored_domain_space[z_col]["min"]
|
236 |
+
)
|
237 |
+
# Plot domain space as a cube
|
238 |
+
trace4 = go.Mesh3d(
|
239 |
+
# 8 vertices of a cube
|
240 |
+
x=x_cube.tolist(),
|
241 |
+
y=y_cube.tolist(),
|
242 |
+
z=z_cube.tolist(),
|
243 |
+
# Keep these values (i, j, k) to get a cube (represent the vertices)
|
244 |
+
i=[7, 0, 0, 0, 4, 4, 6, 6, 4, 0, 3, 2],
|
245 |
+
j=[3, 4, 1, 2, 5, 6, 5, 2, 0, 1, 6, 3],
|
246 |
+
k=[0, 7, 2, 3, 6, 7, 1, 1, 5, 5, 7, 6],
|
247 |
+
opacity=0.3,
|
248 |
+
color="turquoise",
|
249 |
+
flatshading=True,
|
250 |
+
name="Training space",
|
251 |
+
hovertemplate=x_col + ": %{x:.2f}<br>" + y_col + ": %{y:.2f}<br>" + z_col + ": %{z:.2f}"
|
252 |
+
# vertexcolor=["black"] * 12,
|
253 |
+
)
|
254 |
+
fig.add_trace(trace4)
|
255 |
+
return fig
|
256 |
+
|
257 |
+
|
258 |
+
def create_plot(df_synth_pred, explored_space_dict):
|
259 |
+
"""
|
260 |
+
Wrapper to create the function to generate the plotly plots
|
261 |
+
"""
|
262 |
+
# Create plotly plot
|
263 |
+
|
264 |
+
def plot_figure(x):
|
265 |
+
x_params = x[:6]
|
266 |
+
fig = make_domain_plot(df_synth_pred, explored_space_dict, x_params)
|
267 |
+
if len(x) == 6:
|
268 |
+
return fig
|
269 |
+
|
270 |
+
# Case of function call from the inverse design module
|
271 |
+
if len(x) == 9:
|
272 |
+
print("Running optimization visualization")
|
273 |
+
# Add traces corresponding to the additional data points
|
274 |
+
df = x[6]
|
275 |
+
# If empty table (when first loading the interface)
|
276 |
+
if df.shape[1] == 3:
|
277 |
+
return fig
|
278 |
+
|
279 |
+
# Add the values of c_min and c_max to allow to show it in the domain space
|
280 |
+
c_min = x[7]
|
281 |
+
c_max = x[8]
|
282 |
+
df_min = df.copy()
|
283 |
+
df_min["%C"] = c_min
|
284 |
+
df_max = df.copy()
|
285 |
+
df_max["%C"] = c_max
|
286 |
+
|
287 |
+
df_full = pd.concat([df_min, df_max])
|
288 |
+
|
289 |
+
df_filtered, filter_cols = filter_dataframe(x[:6], df_full)
|
290 |
+
trace_name = "Optimization results space"
|
291 |
+
|
292 |
+
# Case of function call from the property prediction module
|
293 |
+
# For now this only supports the alloy space explored with the August 2023 pilot
|
294 |
+
else:
|
295 |
+
df = pd.DataFrame([x[6:]], columns=["%C", "%Co", "%Cr", "%V", "%Mo", "%W", "Temperature_C"])
|
296 |
+
df_filtered, filter_cols = filter_dataframe(x[:6], df)
|
297 |
+
trace_name = "Prediction input space"
|
298 |
+
|
299 |
+
# If no data points matches the selected space
|
300 |
+
if df_filtered.shape[0] == 0:
|
301 |
+
print("No data points matching the selected domain space")
|
302 |
+
return fig
|
303 |
+
|
304 |
+
cols_for_plot = [c for c in df_synth_pred.columns if c not in filter_cols + ["Temperature_C"]]
|
305 |
+
x_col = cols_for_plot[0]
|
306 |
+
y_col = cols_for_plot[1]
|
307 |
+
z_col = cols_for_plot[2]
|
308 |
+
|
309 |
+
trace = go.Scatter3d(
|
310 |
+
x=df_filtered[x_col],
|
311 |
+
y=df_filtered[y_col],
|
312 |
+
z=df_filtered[z_col],
|
313 |
+
mode="markers",
|
314 |
+
name=trace_name,
|
315 |
+
hovertemplate=x_col + ": %{x:.2f}<br>" + y_col + ": %{y:.2f}<br>" + z_col + ": %{z:.2f}",
|
316 |
+
)
|
317 |
+
fig.add_trace(trace)
|
318 |
+
return fig
|
319 |
+
|
320 |
+
def update_figure(x):
|
321 |
+
fig = plot_figure(x)
|
322 |
+
return gr.update(value=fig)
|
323 |
+
|
324 |
+
return lambda *x: plot_figure(x), lambda *x: update_figure(x)
|
325 |
+
|
326 |
+
|
327 |
+
def update_plot(x):
|
328 |
+
fig = create_domain_space(*x)
|
329 |
+
return gr.update(value=fig)
|
330 |
+
|
331 |
+
|
332 |
+
def update_dropdown(*x):
|
333 |
+
input_cols = ["%C", "%Co", "%Cr", "%V", "%Mo", "%W", "Temperature_C"]
|
334 |
+
new_input_cols_list = [input_cols.copy(), input_cols.copy(), input_cols.copy()]
|
335 |
+
for i, val in enumerate(x):
|
336 |
+
for j, new_list in enumerate(new_input_cols_list):
|
337 |
+
if j != i:
|
338 |
+
new_list.remove(val)
|
339 |
+
return (
|
340 |
+
gr.update(choices=new_input_cols_list[0]),
|
341 |
+
gr.update(choices=new_input_cols_list[1]),
|
342 |
+
gr.update(choices=new_input_cols_list[2]),
|
343 |
+
)
|
344 |
+
|
345 |
+
|
346 |
+
def on_select(evt: gr.SelectData): # SelectData is a subclass of EventData
|
347 |
+
print("_________________________________")
|
348 |
+
print(f"You selected {evt.value} at {evt.index} from {evt.target}")
|
349 |
+
return
|
350 |
+
|
351 |
+
|
352 |
+
def create_slicer_update(space_dict):
|
353 |
+
def update_slicer(x):
|
354 |
+
return gr.update(
|
355 |
+
label=x,
|
356 |
+
value=space_dict[x]["value"],
|
357 |
+
minimum=space_dict[x]["min"],
|
358 |
+
maximum=space_dict[x]["max"],
|
359 |
+
step=space_dict[x]["step_display"],
|
360 |
+
)
|
361 |
+
|
362 |
+
return lambda x: update_slicer(x)
|
363 |
+
|
364 |
+
|
365 |
+
def create_gradio(plot_fn, update_plot_fn, update_slider_fn):
|
366 |
+
"""
|
367 |
+
To test the domain space exploration locally
|
368 |
+
"""
|
369 |
+
# css_styling, osium_theme = load_theme()
|
370 |
+
page_title = "Visualize your design space"
|
371 |
+
|
372 |
+
input_cols = ["%C", "%Co", "%Cr", "%V", "%Mo", "%W", "Temperature_C"]
|
373 |
+
|
374 |
+
with gr.Blocks() as demo:
|
375 |
+
gr.Markdown(f"# <p style='text-align: center;'>Adapt your AI models</p>")
|
376 |
+
gr.Markdown("Easily adapt your AI models with your new experimental data")
|
377 |
+
with gr.Row():
|
378 |
+
train_button = gr.Button()
|
379 |
+
with gr.Row():
|
380 |
+
with gr.Column():
|
381 |
+
gr.Markdown("### Your input files")
|
382 |
+
elem1 = "%Cr"
|
383 |
+
elem2 = "%V"
|
384 |
+
elem3 = "%Mo"
|
385 |
+
with gr.Row():
|
386 |
+
input_list1 = input_cols.copy()
|
387 |
+
input_list1.remove(elem2)
|
388 |
+
input_list1.remove(elem3)
|
389 |
+
dropdown_1 = gr.Dropdown(label="Fix element 1", choices=input_list1, value=elem1)
|
390 |
+
input_slicer_1 = gr.Slider(
|
391 |
+
label=elem1,
|
392 |
+
minimum=space_dict[elem1]["min"],
|
393 |
+
maximum=space_dict[elem1]["max"],
|
394 |
+
value=space_dict[elem1]["value"],
|
395 |
+
step=space_dict[elem1]["step_display"],
|
396 |
+
)
|
397 |
+
with gr.Row():
|
398 |
+
input_list2 = input_cols.copy()
|
399 |
+
input_list2.remove(elem1)
|
400 |
+
input_list2.remove(elem3)
|
401 |
+
dropdown_2 = gr.Dropdown(label="Fix element 2", choices=input_list2, value=elem2)
|
402 |
+
input_slicer_2 = gr.Slider(
|
403 |
+
label=elem2,
|
404 |
+
minimum=space_dict[elem2]["min"],
|
405 |
+
maximum=space_dict[elem2]["max"],
|
406 |
+
value=space_dict[elem2]["value"],
|
407 |
+
step=space_dict[elem2]["step_display"],
|
408 |
+
)
|
409 |
+
with gr.Row():
|
410 |
+
input_list3 = input_cols.copy()
|
411 |
+
input_list3.remove(elem1)
|
412 |
+
input_list3.remove(elem2)
|
413 |
+
dropdown_3 = gr.Dropdown(label="Fix element 3", choices=input_list3, value=elem3)
|
414 |
+
input_slicer_3 = gr.Slider(
|
415 |
+
label=elem3,
|
416 |
+
minimum=space_dict[elem3]["min"],
|
417 |
+
maximum=space_dict[elem3]["max"],
|
418 |
+
value=space_dict[elem3]["value"],
|
419 |
+
step=space_dict[elem3]["step_display"],
|
420 |
+
)
|
421 |
+
|
422 |
+
with gr.Column():
|
423 |
+
gr.Markdown("### Your model adaptation")
|
424 |
+
output_plot = gr.Plot(type="plotly")
|
425 |
+
|
426 |
+
train_button.click(
|
427 |
+
fn=plot_fn,
|
428 |
+
inputs=[dropdown_1, input_slicer_1, dropdown_2, input_slicer_2, dropdown_3, input_slicer_3],
|
429 |
+
outputs=[output_plot],
|
430 |
+
show_progress=True,
|
431 |
+
)
|
432 |
+
|
433 |
+
input_slicer_1.change(
|
434 |
+
fn=update_plot_fn,
|
435 |
+
inputs=[dropdown_1, input_slicer_1, dropdown_2, input_slicer_2, dropdown_3, input_slicer_3],
|
436 |
+
outputs=[output_plot],
|
437 |
+
show_progress=True,
|
438 |
+
queue=True,
|
439 |
+
every=0.5,
|
440 |
+
)
|
441 |
+
|
442 |
+
input_slicer_2.change(
|
443 |
+
fn=update_plot_fn,
|
444 |
+
inputs=[dropdown_1, input_slicer_1, dropdown_2, input_slicer_2, dropdown_3, input_slicer_3],
|
445 |
+
outputs=[output_plot],
|
446 |
+
show_progress=True,
|
447 |
+
queue=True,
|
448 |
+
# every=2,
|
449 |
+
)
|
450 |
+
|
451 |
+
input_slicer_3.change(
|
452 |
+
fn=update_plot_fn,
|
453 |
+
inputs=[dropdown_1, input_slicer_1, dropdown_2, input_slicer_2, dropdown_3, input_slicer_3],
|
454 |
+
outputs=[output_plot],
|
455 |
+
show_progress=True,
|
456 |
+
queue=True,
|
457 |
+
# every=2,
|
458 |
+
)
|
459 |
+
|
460 |
+
# Update the choices in the dropdown based on the elements selected
|
461 |
+
# dropdown_1.change(fn=update_dropdown, inputs=[dropdown_1], outputs=[dropdown_2, dropdown_3], show_progress=True)
|
462 |
+
# dropdown_2.change(fn=update_dropdown, inputs=[dropdown_2], outputs=[dropdown_1, dropdown_3], show_progress=True)
|
463 |
+
# dropdown_2.change(fn=update_dropdown, inputs=[dropdown_3], outputs=[dropdown_1, dropdown_2], show_progress=True)
|
464 |
+
dropdown_1.change(
|
465 |
+
fn=update_dropdown,
|
466 |
+
inputs=[dropdown_1, dropdown_2, dropdown_3],
|
467 |
+
outputs=[dropdown_1, dropdown_2, dropdown_3],
|
468 |
+
show_progress=True,
|
469 |
+
)
|
470 |
+
dropdown_2.change(
|
471 |
+
fn=update_dropdown,
|
472 |
+
inputs=[dropdown_1, dropdown_2, dropdown_3],
|
473 |
+
outputs=[dropdown_1, dropdown_2, dropdown_3],
|
474 |
+
show_progress=True,
|
475 |
+
)
|
476 |
+
dropdown_3.change(
|
477 |
+
fn=update_dropdown,
|
478 |
+
inputs=[dropdown_1, dropdown_2, dropdown_3],
|
479 |
+
outputs=[dropdown_1, dropdown_2, dropdown_3],
|
480 |
+
show_progress=True,
|
481 |
+
)
|
482 |
+
|
483 |
+
# Update the slider name based on the choice of the dropdow
|
484 |
+
dropdown_1.change(fn=update_slider_fn, inputs=[dropdown_1], outputs=[input_slicer_1])
|
485 |
+
dropdown_2.change(fn=update_slider_fn, inputs=[dropdown_2], outputs=[input_slicer_2])
|
486 |
+
dropdown_3.change(fn=update_slider_fn, inputs=[dropdown_3], outputs=[input_slicer_3])
|
487 |
+
|
488 |
+
# input_slicer_1.select(on_select, None, None)
|
489 |
+
return demo
|
490 |
+
|
491 |
+
|
492 |
+
if __name__ == "__main__":
|
493 |
+
with open("./conf_test_uncertainty.yaml", "rb") as file:
|
494 |
+
conf = yaml.safe_load(file)
|
495 |
+
space_dict = conf["domain_space"]["uncertainty_space_dict"]
|
496 |
+
explored_dict = conf["domain_space"]["explored_space_dict"]
|
497 |
+
|
498 |
+
# df_synth = create_domain_space(space_dict, conf["inference"], df_path=conf["domain_space"]["design_space_path"])
|
499 |
+
df_synth = load_domain_space(conf["domain_space"]["design_space_path"])
|
500 |
+
|
501 |
+
plot_fn, update_plot_fn = create_plot(df_synth, explored_dict)
|
502 |
+
update_slicer_fn = create_slicer_update(space_dict)
|
503 |
+
demo = create_gradio(plot_fn, update_plot_fn, update_slicer_fn)
|
504 |
+
demo.launch(enable_queue=True)
|