Sanchit2207 commited on
Commit
b3ae3fe
·
verified ·
1 Parent(s): f0c301d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +15 -7
app.py CHANGED
@@ -1,28 +1,36 @@
1
 
2
  from transformers import AutoTokenizer, AutoModelForQuestionAnswering
3
  import torch
 
4
 
5
  tokenizer = AutoTokenizer.from_pretrained("emilyalsentzer/Bio_ClinicalBERT")
6
  model = AutoModelForQuestionAnswering.from_pretrained("emilyalsentzer/Bio_ClinicalBERT")
7
 
8
  def answer_question(context, question):
9
  inputs = tokenizer.encode_plus(question, context, return_tensors="pt")
10
- answer_start_scores, answer_end_scores = model(**inputs).values()
 
 
 
 
 
11
 
12
- start = torch.argmax(answer_start_scores)
13
- end = torch.argmax(answer_end_scores) + 1
14
 
15
  answer = tokenizer.convert_tokens_to_string(
16
  tokenizer.convert_ids_to_tokens(inputs["input_ids"][0][start:end])
17
  )
18
  return answer
19
- import gradio as gr
20
 
21
  def chatbot_response(question):
22
- context = "Include a reliable medical knowledge base here like PubMed abstracts or simplified texts."
 
 
 
 
23
  return answer_question(context, question)
24
 
25
  iface = gr.Interface(fn=chatbot_response, inputs="text", outputs="text", title="Medical Chatbot")
26
  iface.launch()
27
-
28
-
 
1
 
2
  from transformers import AutoTokenizer, AutoModelForQuestionAnswering
3
  import torch
4
+ import gradio as gr
5
 
6
  tokenizer = AutoTokenizer.from_pretrained("emilyalsentzer/Bio_ClinicalBERT")
7
  model = AutoModelForQuestionAnswering.from_pretrained("emilyalsentzer/Bio_ClinicalBERT")
8
 
9
  def answer_question(context, question):
10
  inputs = tokenizer.encode_plus(question, context, return_tensors="pt")
11
+ outputs = model(**inputs)
12
+ start_scores = outputs.start_logits
13
+ end_scores = outputs.end_logits
14
+
15
+ start = torch.argmax(start_scores)
16
+ end = torch.argmax(end_scores) + 1
17
 
18
+ if start >= end:
19
+ return "I couldn't find an answer."
20
 
21
  answer = tokenizer.convert_tokens_to_string(
22
  tokenizer.convert_ids_to_tokens(inputs["input_ids"][0][start:end])
23
  )
24
  return answer
25
+
26
 
27
  def chatbot_response(question):
28
+ context = (
29
+ "COVID-19 is a respiratory illness caused by the SARS-CoV-2 virus. "
30
+ "Common symptoms include fever, cough, fatigue, and loss of taste or smell. "
31
+ "Fever usually lasts for 3-5 days. Treatment is mostly supportive, and vaccination reduces severity."
32
+ )
33
  return answer_question(context, question)
34
 
35
  iface = gr.Interface(fn=chatbot_response, inputs="text", outputs="text", title="Medical Chatbot")
36
  iface.launch()