|
import streamlit as st |
|
import gradio as gr |
|
import shap |
|
import numpy as np |
|
import scipy as sp |
|
import torch |
|
import tensorflow as tf |
|
import transformers |
|
from transformers import pipeline |
|
from transformers import RobertaTokenizer, RobertaModel |
|
from transformers import AutoModelForSequenceClassification |
|
from transformers import TFAutoModelForSequenceClassification |
|
from transformers import AutoTokenizer |
|
import matplotlib.pyplot as plt |
|
import sys |
|
import csv |
|
|
|
# Raise the csv module's per-field size limit as high as the platform allows.
# NOTE(review): csv is not used elsewhere in this view; presumably this is so
# Gradio's cached examples (cache_examples=True below), whose CSV log may hold
# the large SHAP HTML output, can be read back -- confirm.
# sys.maxsize does not fit in a C long where long is 32-bit (e.g. 64-bit
# Windows), and field_size_limit then raises OverflowError, so back off by
# factors of ten until the value is accepted.
_csv_field_limit = sys.maxsize
while True:
    try:
        csv.field_size_limit(_csv_field_limit)
        break
    except OverflowError:
        _csv_field_limit //= 10
|
|
|
# Target the first CUDA GPU when torch can see one; otherwise use the CPU.
if torch.cuda.is_available():
    device = "cuda:0"
else:
    device = "cpu"
|
|
|
# Hugging Face Hub id of the ADR severity classifier.
_ADR_MODEL_ID = "paragon-analytics/ADRv1"

# ADR severity model: tokenizer plus sequence classifier, moved to `device`.
tokenizer = AutoTokenizer.from_pretrained(_ADR_MODEL_ID)
model = AutoModelForSequenceClassification.from_pretrained(_ADR_MODEL_ID).to(device)

# Text-classification pipeline over the same model and tokenizer.
# return_all_scores=True asks for a score per label (presumably so the SHAP
# explainer sees every class -- confirm).
pred = pipeline(
    "text-classification",
    model=model,
    tokenizer=tokenizer,
    return_all_scores=True,
)

# SHAP explainer wrapping the pipeline above.
explainer = shap.Explainer(pred)

# Separate classifier whose score is reported as the "contains medication"
# probability.
classifier = pipeline("text-classification", model="cross-encoder/qnli-electra-base")
|
|
|
def med_score(x):
    """Return the score from one text-classification result, rounded to 3 places.

    Args:
        x: a single result dict from a Hugging Face text-classification
            pipeline. Only its ``'score'`` entry is read; the label is ignored.

    Returns:
        ``x['score']`` rounded to three decimal places.

    Raises:
        KeyError: if ``x`` has no ``'score'`` key.
    """
    # The label is intentionally not consulted: adr_predict reports this score
    # directly as the "Contains Medication" probability.
    return round(x['score'], 3)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def adr_predict(x):
    """Classify ADR severity of *x*, explain it with SHAP, and score medication mentions.

    Args:
        x: input text (``main`` passes it already lower-cased).

    Returns:
        A 3-tuple:
        - dict mapping "Severe Reaction" / "Non-severe Reaction" to the softmax
          probability of the ADR model's class index 1 / 0;
        - the SHAP text plot of the lower-cased input, as returned by
          ``shap.plots.text(..., display=False)``;
        - dict mapping "Contains Medication" to the medication classifier's
          score (rounded to 3 decimals) and "No Medications" to its complement.
    """
    # Put the inputs on the model's own device. Leaving them on the CPU fails
    # with a device mismatch whenever the model lives on CUDA. Truncation keeps
    # over-long inputs from exceeding the model's maximum sequence length.
    encoded_input = tokenizer(x, return_tensors='pt', truncation=True).to(model.device)
    with torch.no_grad():  # inference only; no autograd graph needed
        # First element of the model output (the logits) for the single input.
        logits = model(**encoded_input)[0][0]
    # .cpu() is required before .numpy() when the tensor is on a GPU.
    scores = torch.softmax(logits, dim=-1).cpu().numpy()

    shap_values = explainer([str(x).lower()])
    local_plot = shap.plots.text(shap_values[0], display=False)

    # Append a fixed statement to the text and report the classifier's score
    # as the probability that the text mentions a medication.
    med = med_score(classifier(x + ", There is a medication.")[0])

    severity = {
        "Severe Reaction": float(scores[1]),
        "Non-severe Reaction": float(scores[0]),
    }
    medication = {"Contains Medication": float(med), "No Medications": float(1 - med)}
    return severity, local_plot, medication
|
|
|
|
|
|
|
def main(prob1):
    """Gradio callback: lower-case the textbox value and forward adr_predict's outputs.

    Args:
        prob1: raw value of the input textbox (coerced with ``str``).

    Returns:
        A tuple of the first three items produced by ``adr_predict``
        (severity label dict, SHAP plot, medication label dict).
    """
    lowered = str(prob1).lower()
    return tuple(adr_predict(lowered)[:3])
|
|
|
|
|
# Heading used both as the browser tab title and as the page's top heading.
title = "Welcome to **ADR Detector** 🪐"

# Short usage note rendered under the heading.
description1 = """This app takes text (up to a few sentences) and predicts to what extent the text describes severe (or non-severe) adverse reactions to medications. Please do NOT use for medical diagnosis."""
|
|
|
# Page layout:
# - heading and description;
# - an input textbox and an "Analyze" button;
# - two output columns (severity label plus SHAP plot, medication label);
# - a row of clickable, cached examples.
with gr.Blocks(title=title) as demo:
    gr.Markdown(f"## {title}")
    gr.Markdown(description1)
    gr.Markdown("""---""")

    prob1 = gr.Textbox(
        label="Enter Your Text Here:",
        lines=2,
        placeholder="Type it here ...",
    )
    submit_btn = gr.Button("Analyze")

    with gr.Row():
        # Severity probabilities and the SHAP text explanation.
        with gr.Column(visible=True) as output_col:
            label = gr.Label(label="Predicted Label")
            local_plot = gr.HTML(label='Shap:')
        # Medication-mention probabilities.
        # NOTE(review): this rebinds `output_col`, so the first column's
        # handle is no longer reachable by name.
        with gr.Column(visible=True) as output_col:
            med = gr.Label(label="Contains Medication")

    # Clicking "Analyze" runs `main` on the textbox; exposed as API "adr".
    submit_btn.click(
        fn=main,
        inputs=[prob1],
        outputs=[label, local_plot, med],
        api_name="adr",
    )

    with gr.Row():
        gr.Markdown("### Click on any of the examples below to see how it works:")
        # cache_examples=True runs `main` on these examples ahead of time.
        gr.Examples(
            examples=[
                ["I had severe headache after taking Aspirin."],
                ["I had minor stomachache after taking Acetaminophen."],
            ],
            inputs=[prob1],
            outputs=[label, local_plot, med],
            fn=main,
            cache_examples=True,
        )
|
|
|
# Start the Gradio server only when this file is executed as a script, so that
# importing the module (e.g. from tests or another app) does not start the
# server.
if __name__ == "__main__":
    demo.launch()
|
|