File size: 1,276 Bytes
224cd33
d182e77
 
b3ae3fe
f68bac7
d182e77
 
4372d3f
d182e77
 
b3ae3fe
 
 
 
 
 
4372d3f
b3ae3fe
 
4372d3f
d182e77
 
 
 
b3ae3fe
f0c301d
 
b3ae3fe
 
 
 
 
f0c301d
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37

from transformers import AutoTokenizer, AutoModelForQuestionAnswering
import torch
import gradio as gr

# Hugging Face checkpoint used for both the tokenizer and the QA model.
# NOTE(review): "emilyalsentzer/Bio_ClinicalBERT" appears to be a pretrained
# (masked-LM) checkpoint rather than one fine-tuned for extractive QA. If so,
# transformers initialises the question-answering head randomly and the
# extracted spans are essentially arbitrary -- confirm, and consider a
# checkpoint fine-tuned on a QA dataset (e.g. SQuAD-style).
_CHECKPOINT = "emilyalsentzer/Bio_ClinicalBERT"

tokenizer = AutoTokenizer.from_pretrained(_CHECKPOINT)
model = AutoModelForQuestionAnswering.from_pretrained(_CHECKPOINT)

def _context_token_mask(inputs, seq_len):
    """Return a bool tensor of length *seq_len* marking context-segment tokens.

    Fast tokenizers can report which input sequence each token came from via
    ``sequence_ids`` (None for special tokens, 0 for the first sequence -- the
    question -- and 1 for the second -- the context). Slow tokenizers cannot,
    so in that case every position is allowed (the previous behaviour).
    """
    if getattr(inputs, "is_fast", False):
        return torch.tensor(
            [seq_id == 1 for seq_id in inputs.sequence_ids(0)], dtype=torch.bool
        )
    return torch.ones(seq_len, dtype=torch.bool)


def _best_span(start_logits, end_logits, allowed, max_len):
    """Return ``(start, end_exclusive)`` of the highest-scoring valid span.

    A span ``(s, e)`` scores ``start_logits[s] + end_logits[e]`` and is valid
    when ``s <= e < s + max_len`` and both ``s`` and ``e`` are ``allowed``.
    All three tensors are 1-D and of equal length. Returns None when no valid
    span exists.
    """
    n = start_logits.shape[0]
    device = start_logits.device
    allowed = allowed.to(device)
    pos = torch.arange(n, device=device)
    # Row index = candidate start, column index = candidate end.
    length = pos[None, :] - pos[:, None]
    valid = (length >= 0) & (length < max_len)
    valid &= allowed[:, None] & allowed[None, :]
    if not bool(valid.any()):
        return None
    scores = (start_logits[:, None] + end_logits[None, :]).masked_fill(
        ~valid, float("-inf")
    )
    start, end = divmod(int(torch.argmax(scores)), n)
    return start, end + 1


def answer_question(context, question, max_answer_tokens=30):
    """Extract the answer to *question* from *context* with the QA model.

    Uses the module-level ``tokenizer`` and ``model``. The answer is the
    highest-scoring span (start logit + end logit) that lies inside the context
    segment -- never in the question or on special tokens -- and is at most
    ``max_answer_tokens`` tokens long. Contexts longer than the model's input
    limit are truncated, so text past that point cannot be answered from.

    Args:
        context: Passage to search for the answer.
        question: Question to answer.
        max_answer_tokens: Upper bound on the answer length, in tokens.

    Returns:
        The decoded answer text, or "I couldn't find an answer." when either
        input is empty, no valid span exists, or the span decodes to nothing.
    """
    no_answer = "I couldn't find an answer."
    if not (context and context.strip() and question and question.strip()):
        return no_answer

    # Never feed more tokens than the model can embed; truncate only the
    # context so the question is always seen in full.
    max_length = min(tokenizer.model_max_length, model.config.max_position_embeddings)
    inputs = tokenizer(
        question,
        context,
        truncation="only_second",
        max_length=max_length,
        return_tensors="pt",
    )

    # Inference only: skip autograd bookkeeping.
    with torch.no_grad():
        outputs = model(**inputs)

    start_logits = outputs.start_logits[0]
    end_logits = outputs.end_logits[0]
    allowed = _context_token_mask(inputs, start_logits.shape[0])
    span = _best_span(start_logits, end_logits, allowed, max_answer_tokens)
    if span is None:
        return no_answer

    start, end = span
    answer = tokenizer.decode(
        inputs["input_ids"][0][start:end], skip_special_tokens=True
    ).strip()
    return answer or no_answer


def chatbot_response(question):
    """Answer a free-text *question* against a fixed COVID-19 passage.

    The passage below is the only knowledge source; the actual span
    extraction is delegated to ``answer_question``.
    """
    passage_parts = (
        "COVID-19 is a respiratory illness caused by the SARS-CoV-2 virus. ",
        "Common symptoms include fever, cough, fatigue, and loss of taste or smell. ",
        "Fever usually lasts for 3-5 days. Treatment is mostly supportive, and vaccination reduces severity.",
    )
    passage = "".join(passage_parts)
    return answer_question(passage, question)

# Single text-in / text-out Gradio UI wrapping chatbot_response.
iface = gr.Interface(fn=chatbot_response, inputs="text", outputs="text", title="Medical Chatbot")

# Start the web server only when this file is run as a script
# (e.g. `python app.py`), so importing the module does not launch it.
if __name__ == "__main__":
    iface.launch()