rahideer committed on
Commit
c54ca35
·
verified ·
1 Parent(s): fba73d5

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +33 -34
app.py CHANGED
@@ -1,41 +1,40 @@
1
import streamlit as st
from transformers import RagTokenizer, RagRetriever, RagSequenceForGeneration
from datasets import load_dataset
import torch

# Load the dataset
# NOTE(review): "pubmed_qa" normally requires a config name (e.g. "pqa_labeled")
# and may not expose a "test" split — confirm against the dataset card.
# The dataset is currently loaded but never used below.
dataset = load_dataset("pubmed_qa", split="test")

# Initialize RAG components
tokenizer = RagTokenizer.from_pretrained("facebook/rag-token-nq")
retriever = RagRetriever.from_pretrained("facebook/rag-token-nq", index_name="default", use_dummy_dataset=True)
# Attach the retriever to the model so that `generate()` performs retrieval
# internally — the documented Hugging Face RAG usage. The previous code called
# `retriever(input_ids=...)`, but RagRetriever.__call__ requires question
# hidden states, not raw token ids, so that call raised a TypeError.
model = RagSequenceForGeneration.from_pretrained("facebook/rag-sequence-nq", retriever=retriever)


# Function to get the answer to a medical query
def get_medical_answer(query):
    """Return the RAG model's generated answer for a free-text medical question.

    Args:
        query: the user's question as a plain string.

    Returns:
        The decoded answer string (special tokens stripped).
    """
    # Encode the query; retrieval + generation happen inside generate()
    # via the retriever attached to the model above.
    inputs = tokenizer(query, return_tensors="pt")
    input_ids = inputs["input_ids"]

    # no_grad: inference only, avoids building an autograd graph.
    with torch.no_grad():
        generated_ids = model.generate(input_ids=input_ids)

    # Decode the generated answer
    return tokenizer.decode(generated_ids[0], skip_special_tokens=True)


# Streamlit UI
st.title("Medical QA Assistant")
st.write("Ask any medical question, and I will answer it based on PubMed papers!")

# Input text box for queries
query = st.text_input("Enter your medical question:")

if query:
    with st.spinner("Searching for the answer..."):
        answer = get_medical_answer(query)
    st.write(f"Answer: {answer}")
 
 
 
 
 
 
1
import streamlit as st
from transformers import RagTokenizer, RagRetriever, RagSequenceForGeneration
from datasets import load_dataset

# Load dataset (pubmed_qa) and tokenizer
# NOTE(review): "pubmed_qa" normally requires a config name (e.g. "pqa_labeled")
# and may not expose a "test" split — confirm against the dataset card.
# The dataset is currently loaded but never used below.
dataset = load_dataset("pubmed_qa", split="test")

tokenizer = RagTokenizer.from_pretrained("facebook/rag-token-nq")
# NOTE(review): passages_path must point at a dataset saved with
# `datasets.Dataset.save_to_disk`; "./path_to_dataset" looks like a placeholder
# that needs to be replaced with a real path before this can run.
retriever = RagRetriever.from_pretrained("facebook/rag-token-nq", index_name="compressed", passages_path="./path_to_dataset")

# Initialize the RAG model. Attaching the retriever lets `generate()` perform
# retrieval internally, which is the documented usage; the previous code passed
# the retriever's output where token ids were expected, which cannot work.
# NOTE(review): RagSequenceForGeneration with the "rag-token-nq" checkpoint is
# an architecture/weights mismatch — consider "facebook/rag-sequence-nq".
model = RagSequenceForGeneration.from_pretrained("facebook/rag-token-nq", retriever=retriever)

# Define Streamlit app
st.title('Medical QA Assistant')

st.markdown("This app uses a RAG model to answer medical queries based on the PubMed QA dataset.")

# User input for query
user_query = st.text_input("Ask a medical question:")

if user_query:
    # Tokenize input question and retrieve related documents
    inputs = tokenizer(user_query, return_tensors="pt")
    input_ids = inputs['input_ids']

    # Use the retriever to get context: RagRetriever.__call__ takes the
    # question token ids plus the question-encoder hidden states (as numpy),
    # not raw token ids alone as the previous `retriever.retrieve(input_ids)`
    # assumed. `get_doc_dicts` maps retrieved doc ids back to passage text.
    question_hidden_states = model.question_encoder(input_ids)[0]
    docs = retriever(input_ids.numpy(), question_hidden_states.detach().numpy(), return_tensors="pt")
    doc_dicts = retriever.index.get_doc_dicts(docs["doc_ids"])

    # Generate an answer based on the context (retrieval happens inside
    # generate() via the retriever attached to the model above).
    generated_ids = model.generate(input_ids=input_ids)
    answer = tokenizer.decode(generated_ids[0], skip_special_tokens=True)

    # Show the answer
    st.write(f"Answer: {answer}")

    # Display the most relevant documents
    st.subheader("Relevant Documents:")
    for text in doc_dicts[0]["text"]:
        st.write(text[:300] + '...')  # Display first 300 characters of each doc