snajmark commited on
Commit
a3d7967
·
1 Parent(s): b56c937

Update domain_space.py

Browse files
Files changed (1) hide show
  1. domain_space.py +7 -7
domain_space.py CHANGED
@@ -193,12 +193,12 @@ def interpolate_all(params_list, df):
193
  return df_filtered, filter_cols
194
 
195
 
196
- def make_domain_plot(df_synth_pred, explored_domain_space, x_list):
197
  """
198
  Create a plot with the uncertainty space and the training space
199
  """
200
- uncertainty_min = df_synth_pred["y_pred"].min()
201
- uncertainty_max = df_synth_pred["y_pred"].max()
202
 
203
  # df_synth_pred2, filter_cols = filter_dataframe(x_list[:6], df_synth_pred)
204
  df_synth_pred2, filter_cols = interpolate_all(x_list[:6], df_synth_pred)
@@ -210,9 +210,9 @@ def make_domain_plot(df_synth_pred, explored_domain_space, x_list):
210
  x=x_col,
211
  y=y_col,
212
  z=z_col,
213
- color="y_pred",
214
  range_color=[uncertainty_min, uncertainty_max],
215
- hover_data={"y_pred": ":.3f"},
216
  )
217
 
218
  # Filter domain space
@@ -256,7 +256,7 @@ def make_domain_plot(df_synth_pred, explored_domain_space, x_list):
256
  return fig
257
 
258
 
259
- def create_plot(df_synth_pred, explored_space_dict):
260
  """
261
  Wrapper to create the function to generate the plotly plots
262
  """
@@ -264,7 +264,7 @@ def create_plot(df_synth_pred, explored_space_dict):
264
 
265
  def plot_figure(x):
266
  x_params = x[:6]
267
- fig = make_domain_plot(df_synth_pred, explored_space_dict, x_params)
268
  if len(x) == 6:
269
  return fig
270
 
 
193
  return df_filtered, filter_cols
194
 
195
 
196
+ def make_domain_plot(df_synth_pred, explored_domain_space, x_list, target="y_pred"):
197
  """
198
  Create a plot with the uncertainty space and the training space
199
  """
200
+ uncertainty_min = df_synth_pred[target].min()
201
+ uncertainty_max = df_synth_pred[target].max()
202
 
203
  # df_synth_pred2, filter_cols = filter_dataframe(x_list[:6], df_synth_pred)
204
  df_synth_pred2, filter_cols = interpolate_all(x_list[:6], df_synth_pred)
 
210
  x=x_col,
211
  y=y_col,
212
  z=z_col,
213
+ color= target,
214
  range_color=[uncertainty_min, uncertainty_max],
215
+ hover_data={target: ":.3f"},
216
  )
217
 
218
  # Filter domain space
 
256
  return fig
257
 
258
 
259
+ def create_plot(df_synth_pred, explored_space_dict, target):
260
  """
261
  Wrapper to create the function to generate the plotly plots
262
  """
 
264
 
265
  def plot_figure(x):
266
  x_params = x[:6]
267
+ fig = make_domain_plot(df_synth_pred, explored_space_dict, x_params, target)
268
  if len(x) == 6:
269
  return fig
270