bndl commited on
Commit
6d548db
·
1 Parent(s): 325df2e

Update template_gradio_interface.py

Browse files
Files changed (1) hide show
  1. template_gradio_interface.py +10 -1
template_gradio_interface.py CHANGED
@@ -52,7 +52,16 @@ def call_predict(inference_dict, cols_order):
52
  y_pred_rescaled = scaler_targets.inverse_transform(y_pred)
53
 
54
  plt.clf()
55
- shap.summary_plot(shap_values[0], feature_names=df_preprocessed.columns)
 
 
 
 
 
 
 
 
 
56
  fig = plt.gcf()
57
 
58
  print("mmmmmmmmmmmmmmmmmmmmm")
 
52
  y_pred_rescaled = scaler_targets.inverse_transform(y_pred)
53
 
54
  plt.clf()
55
+ # shap.summary_plot(shap_values[0], feature_names=df_preprocessed.columns)
56
+ plt.figure(figsize=(15, 15))
57
+ plt.subplot(1,2,1)
58
+ shap.summary_plot(shap_values[0], input, show=False, feature_names=df_preprocessed.columns, plot_size=(15, 15))
59
+ plt.subplot(1,2,2)
60
+ shap.summary_plot(shap_values[1], input, show=False, feature_names=df_preprocessed.columns, plot_size=(15, 15))
61
+ plt.subplot(1,2,3)
62
+ shap.summary_plot(shap_values[2], input, show=False, feature_names=df_preprocessed.columns, plot_size=(15, 15))
63
+ plt.tight_layout()
64
+ plt.subplots_adjust(wspace=2.0)
65
  fig = plt.gcf()
66
 
67
  print("mmmmmmmmmmmmmmmmmmmmm")