bndl commited on
Commit
703aa87
·
1 Parent(s): 20d816d

Update utils.py

Browse files
Files changed (1) hide show
  1. utils.py +7 -1
utils.py CHANGED
@@ -76,8 +76,14 @@ def encode_and_predict(model_path, data, one_hot_scaler, minmax_scaler_inputs, m
76
  else:
77
  return model.predict(data)
78
 
79
- def predict(model_path, data, explainer=None):
80
  model = tf.keras.models.load_model(model_path)
 
 
 
 
 
 
81
  if explainer:
82
  return model.predict(data), data.columns, explainer.shap_values(data[-10:])
83
  else:
 
76
  else:
77
  return model.predict(data)
78
 
79
+ def predict(model_path, data, explainer=None, df_train=None):
80
  model = tf.keras.models.load_model(model_path)
81
+
82
+ if df_train is not None:
83
+
84
+ explainer = shap.KernelExplainer(model.predict, df_train[:10])
85
+ return model.predict(data), data.columns, explainer.shap_values(data[-10:])
86
+
87
  if explainer:
88
  return model.predict(data), data.columns, explainer.shap_values(data[-10:])
89
  else: