|
import streamlit as st |
|
import gradio as gr |
|
import shap |
|
import torch |
|
import tensorflow as tf |
|
from transformers import RobertaTokenizer, RobertaModel |
|
from transformers import AutoModelForSequenceClassification |
|
from transformers import TFAutoModelForSequenceClassification |
|
from transformers import AutoTokenizer |
|
|
|
# Hugging Face Hub checkpoint shared by the tokenizer and the classifier.
_ADR_CHECKPOINT = "paragon-analytics/ADRv1"

# Both objects are loaded once at import time and reused by every request.
tokenizer = AutoTokenizer.from_pretrained(_ADR_CHECKPOINT)
model = AutoModelForSequenceClassification.from_pretrained(_ADR_CHECKPOINT)
|
|
|
# Lazily-built SHAP explainer; see _get_adr_explainer().
_ADR_EXPLAINER = None


def _get_adr_explainer():
    """Return the SHAP explainer for the classifier, building it on first use."""
    global _ADR_EXPLAINER
    if _ADR_EXPLAINER is None:
        # Local import: the file only does ``from transformers import ...``,
        # which never binds the bare name ``transformers`` the original
        # ``transformers.pipeline(...)`` call relied on (NameError).
        from transformers import pipeline

        # Use the first GPU when one exists, otherwise stay on the CPU rather
        # than failing on hosts without CUDA (the original hard-coded device=0).
        device = 0 if torch.cuda.is_available() else -1
        pred = pipeline(
            "text-classification",
            model=model,
            tokenizer=tokenizer,
            device=device,
            return_all_scores=True,
        )
        _ADR_EXPLAINER = shap.Explainer(pred)
    return _ADR_EXPLAINER


def adr_predict(x):
    """Score ``x`` with the ADR classifier and attribute the score to tokens.

    Args:
        x: The text to classify.

    Returns:
        A 2-tuple ``(label_scores, word_scores)``.

        ``label_scores`` maps ``"Severe Reaction"`` to the softmax probability
        of class index 1 and ``"Non-severe Reaction"`` to that of class
        index 0.

        ``word_scores`` is a list of ``(token, shap_value)`` pairs, the input
        format of the ``gr.HighlightedText`` component this result is wired
        to. The original returned the result of ``shap.plots.text``, which is
        a display helper and not such a list.
    """
    # The pipeline built for SHAP may move the shared model to the GPU, so the
    # inputs must follow the model's current device on every call.
    encoded_input = tokenizer(x, return_tensors='pt').to(model.device)
    with torch.no_grad():  # inference only; no autograd graph needed
        output = model(**encoded_input)
    # output[0] holds the logits; [0] selects the single input in the batch.
    scores = torch.softmax(output[0][0], dim=-1).cpu().numpy()

    shap_values = _get_adr_explainer()([x])
    tokens = shap_values.data[0]
    # NOTE(review): assumes SHAP output column 1 is the same class as logit
    # index 1 (reported as "Severe Reaction" below) -- confirm against the
    # checkpoint's label mapping.
    severe_contributions = shap_values.values[0][:, 1]
    word_scores = [
        (str(token), float(value))
        for token, value in zip(tokens, severe_contributions)
    ]

    label_scores = {
        "Severe Reaction": float(scores[1]),
        "Non-severe Reaction": float(scores[0]),
    }
    return label_scores, word_scores
|
|
|
def main(text):
    """Gradio callback: lower-case the input and run the ADR prediction.

    Args:
        text: Value of the input textbox; it is coerced to ``str`` first.

    Returns:
        The two values produced by ``adr_predict``, in the same order.
    """
    label_scores, word_scores = adr_predict(str(text).lower())
    return label_scores, word_scores
|
|
|
# Heading text (Markdown); also passed to gr.Blocks as the page title.
title = "Welcome to **ADR Detector** 🪐"

# Introductory paragraph rendered as Markdown under the heading.
# Fixes the "medicaitons" typo and the missing article/plural in the original.
description1 = """
This app takes text (up to a few sentences) and predicts to what extent the text describes a severe (or non-severe) adverse reaction to medications.
"""
|
|
|
# Page layout: heading and description, a text input with an "Analyze"
# button, the two outputs (label probabilities and per-word scores), and
# clickable examples.
with gr.Blocks(title=title) as demo:
    gr.Markdown(f"## {title}")
    gr.Markdown(description1)
    gr.Markdown("""---""")
    text = gr.Textbox(label="Enter Your Text Here:", lines=2, placeholder="Type it here ...")
    submit_btn = gr.Button("Analyze")

    with gr.Column(visible=True) as output_col:
        label = gr.Label(label="Predicted Label")
        shap_plot = gr.HighlightedText(label="Word Scores", combine_adjacent=False)

    # main returns two values, matched positionally to [label, shap_plot];
    # api_name is the name this event is registered under in the API.
    submit_btn.click(
        main,
        [text],
        [label, shap_plot],
        api_name="adr",
    )

    # The original caption spoke of "resilience messaging", which does not
    # match this app; it now describes what the classifier predicts.
    gr.Markdown("### Click on any of the examples below to see to what extent they describe a severe adverse reaction:")
    # cache_examples=True asks Gradio to cache main's outputs for the examples.
    gr.Examples(
        [["I have minor pain."], ["I have severe pain."]],
        [text],
        [label, shap_plot],
        main,
        cache_examples=True,
    )

demo.launch()
|
|