rahideer commited on
Commit
f0cab08
·
verified ·
1 Parent(s): e20efc8

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +34 -16
app.py CHANGED
@@ -1,23 +1,41 @@
1
  import streamlit as st
2
- from rag_pipeline import RAGPipeline
 
 
3
 
4
- st.set_page_config(page_title="Medical QA Assistant", page_icon="🩺")
5
- st.title("🩺 Medical QA Assistant")
6
- st.markdown("Ask any medical question and get evidence-based answers from PubMed.")
7
 
8
- @st.cache_resource
9
- def load_rag():
10
- return rag = RAGPipeline()
 
11
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
12
 
13
- rag = load_rag()
 
 
14
 
15
- question = st.text_input("Enter your medical question:")
 
16
 
17
- if st.button("Get Answer"):
18
- if question:
19
- with st.spinner("Searching PubMed and generating answer..."):
20
- answer = rag.generate_answer(question)
21
- st.success(answer)
22
- else:
23
- st.warning("Please enter a question.")
 
1
  import streamlit as st
2
+ from transformers import RagTokenizer, RagRetriever, RagSequenceForGeneration
3
+ from datasets import load_dataset
4
+ import torch
5
 
6
+ # Load the dataset
7
+ dataset = load_dataset("pubmed_qa", split="test")
 
8
 
9
+ # Initialize RAG components
10
+ tokenizer = RagTokenizer.from_pretrained("facebook/rag-token-nq")
11
+ retriever = RagRetriever.from_pretrained("facebook/rag-token-nq", index_name="default", use_dummy_dataset=True)
12
+ model = RagSequenceForGeneration.from_pretrained("facebook/rag-sequence-nq")
13
 
14
+ # Function to get the answer to a medical query
15
+ def get_medical_answer(query):
16
+ # Encode the query to retrieve relevant documents
17
+ inputs = tokenizer(query, return_tensors="pt")
18
+ input_ids = inputs["input_ids"]
19
+
20
+ # Retrieve relevant documents
21
+ docs = retriever(input_ids=input_ids, return_tensors="pt")
22
+
23
+ # Generate the answer from the model
24
+ generated_ids = model.generate(input_ids=input_ids, context_input_ids=docs["context_input_ids"],
25
+ context_attention_mask=docs["context_attention_mask"])
26
+
27
+ # Decode the generated answer
28
+ generated_answer = tokenizer.decode(generated_ids[0], skip_special_tokens=True)
29
+ return generated_answer
30
 
31
+ # Streamlit UI
32
+ st.title("Medical QA Assistant")
33
+ st.write("Ask any medical question, and I will answer it based on PubMed papers!")
34
 
35
+ # Input text box for queries
36
+ query = st.text_input("Enter your medical question:")
37
 
38
+ if query:
39
+ with st.spinner("Searching for the answer..."):
40
+ answer = get_medical_answer(query)
41
+ st.write(f"Answer: {answer}")