Spaces:
Runtime error
Runtime error
Update app.py
Browse files
app.py
CHANGED
|
@@ -53,16 +53,14 @@ def adr_predict(x):
|
|
| 53 |
scores = output[0][0].detach()
|
| 54 |
scores = torch.nn.functional.softmax(scores)
|
| 55 |
|
| 56 |
-
shap_values = explainer([str(
|
| 57 |
# # Find the index of the class you want as the default reference (e.g., 'label_1')
|
| 58 |
# label_1_index = np.where(np.array(explainer.output_names) == 'label_1')[0][0]
|
| 59 |
-
label_1_index = np.where(np.array(explainer.output_names) == 'label_1')[0][0]
|
| 60 |
|
| 61 |
# # Plot the SHAP values for a specific instance in your dataset (e.g., instance 0)
|
| 62 |
# shap.plots.text(shap_values[label_1_index][0])
|
| 63 |
-
|
| 64 |
-
|
| 65 |
-
local_plot = shap.plots.text(shap_values[1], display=False)
|
| 66 |
|
| 67 |
# med = med_score(classifier(x+str(", There is a medication."))[0])
|
| 68 |
# sym = sym_score(classifier(x+str(", There is a symptom."))[0])
|
|
|
|
| 53 |
scores = output[0][0].detach()
|
| 54 |
scores = torch.nn.functional.softmax(scores)
|
| 55 |
|
| 56 |
+
shap_values = explainer([str(x).lower()])
|
| 57 |
# # Find the index of the class you want as the default reference (e.g., 'label_1')
|
| 58 |
# label_1_index = np.where(np.array(explainer.output_names) == 'label_1')[0][0]
|
|
|
|
| 59 |
|
| 60 |
# # Plot the SHAP values for a specific instance in your dataset (e.g., instance 0)
|
| 61 |
# shap.plots.text(shap_values[label_1_index][0])
|
| 62 |
+
|
| 63 |
+
local_plot = shap.plots.text(shap_values[0], display=False)
|
|
|
|
| 64 |
|
| 65 |
# med = med_score(classifier(x+str(", There is a medication."))[0])
|
| 66 |
# sym = sym_score(classifier(x+str(", There is a symptom."))[0])
|