Update app.py
app.py
CHANGED
@@ -67,16 +67,16 @@ class Retriever:
     def load_chunks(self):
         self.text = self.extract_text_from_pdf(self.file_path)
         text_splitter = RecursiveCharacterTextSplitter(
-            chunk_size=
+            chunk_size=150,
             chunk_overlap=20,
             length_function=self.token_len,
-            separators=["\n\n", "
+            separators=["Section", "\n\n", "\n", ".", " ", ""]
         )

         self.chunks = text_splitter.split_text(self.text)

     def load_context_embeddings(self):
-        encoded_input = self.context_tokenizer(self.chunks, return_tensors='pt', padding=True, truncation=True, max_length=
+        encoded_input = self.context_tokenizer(self.chunks, return_tensors='pt', padding=True, truncation=True, max_length=300).to(device)

         with torch.no_grad():
             model_output = self.context_model(**encoded_input)
@@ -89,20 +89,16 @@ class Retriever:
         encoded_query = self.question_tokenizer(query_prompt, return_tensors="pt", truncation=True, padding=True).to(device)

         with torch.no_grad():
-
-
+            model_output = self.question_model(**encoded_query)
+            query_vector = model_output.pooler_output

         query_vector_np = query_vector.cpu().numpy()
         D, I = self.index.search(query_vector_np, k)

-        retrieved_texts = [self.chunks[i] for i in I[0]]
+        retrieved_texts = [' '.join(self.chunks[i].split('\n')) for i in I[0]]  # Replacing newlines with spaces

         scores = [d for d in D[0]]

-        # print("Top 5 retrieved texts and their associated scores:")
-        # for idx, (text, score) in enumerate(zip(retrieved_texts, scores)):
-        #     print(f"{idx + 1}. Text: {text} \n Score: {score:.4f}\n")
-
         return retrieved_texts

 class RAG:
@@ -115,22 +111,23 @@ class RAG:

         # generator_name = "valhalla/bart-large-finetuned-squadv1"
         # generator_name = "'vblagoje/bart_lfqa'"
-        generator_name = "a-ware/bart-squadv2"
-
+        # generator_name = "a-ware/bart-squadv2"
+
         self.generator_tokenizer = BartTokenizer.from_pretrained(generator_name)
         self.generator_model = BartForConditionalGeneration.from_pretrained(generator_name).to(device)

+        # generator_name = "MaRiOrOsSi/t5-base-finetuned-question-answering"
+        # generator_name = "t5-small"
+
+        # self.generator_tokenizer = T5Tokenizer.from_pretrained(generator_name)
+        # self.generator_model = T5ForConditionalGeneration.from_pretrained(generator_name)
+
         self.retriever = Retriever(file_path, device, context_model_name, question_model_name)
         self.retriever.load_chunks()
         self.retriever.load_context_embeddings()

-    def get_answer(self, question, context):
-        input_text = "context: %s <question for context: %s </s>" % (context,question)
-        features = self.generator_tokenizer([input_text], return_tensors='pt')
-        out = self.generator_model.generate(input_ids=features['input_ids'].to(device), attention_mask=features['attention_mask'].to(device))
-        return self.generator_tokenizer.decode(out[0])

-    def
+    def abstractive_query(self, question):
         context = self.retriever.retrieve_top_k(question, k=5)
         # input_text = question + " " + " ".join(context)

@@ -144,22 +141,46 @@ class RAG:
         answer = self.generator_tokenizer.decode(outputs[0], skip_special_tokens=True)
         return answer

+    def extractive_query(self, question):
+        context = self.retriever.retrieve_top_k(question, k=15)
+        generator_name = "valhalla/bart-large-finetuned-squadv1"

-
-
-question_model_name="facebook/dpr-question_encoder-multiset-base"
+        self.generator_tokenizer = AutoTokenizer.from_pretrained(generator_name)
+        self.generator_model = BartForQuestionAnswering.from_pretrained(generator_name).to(device)

-
+        inputs = self.generator_tokenizer(question, ". ".join(context), return_tensors="pt", truncation=True, max_length=200 , padding="max_length")
+        with torch.no_grad():
+            model_inputs = inputs.to(device)
+            outputs = self.generator_model(**model_inputs)
+
+        answer_start_index = outputs.start_logits.argmax()
+        answer_end_index = outputs.end_logits.argmax()

-
+        if answer_end_index < answer_start_index:
+            answer_start_index, answer_end_index = answer_end_index, answer_start_index

-print(
+        print(answer_start_index, answer_end_index)
+
+        predict_answer_tokens = inputs.input_ids[0, answer_start_index : answer_end_index + 1]
+        answer = self.generator_tokenizer.decode(predict_answer_tokens, skip_special_tokens=True)
+        answer = answer.replace('\n', ' ').strip()
+        answer = answer.replace('$', '')
+
+        return answer
+
+context_model_name="facebook/dpr-ctx_encoder-single-nq-base"
+question_model_name = "facebook/dpr-question_encoder-single-nq-base"
+# context_model_name="facebook/dpr-ctx_encoder-multiset-base"
+# question_model_name="facebook/dpr-question_encoder-multiset-base"
+
+rag = RAG(file_path, device)

 st.title("RAG Model Query Interface")

-
+# offer to ask a question and get an answer. make it so they can ask as many questions as they want
+
+question = st.text_input("Ask a question", "What is another name for self-attention?")

-
-
-
-    st.write(answer)
+if st.button("Ask"):
+    answer = rag.extractive_query(question)
+    st.write(answer)
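Note: `token_len` is passed as the splitter's `length_function` but is defined outside the hunks shown here. A minimal sketch of what such a helper typically looks like, assuming it measures length in tokenizer tokens so that `chunk_size=150` and `chunk_overlap=20` are token budgets (the use of `self.context_tokenizer` here is an assumption, not taken from this commit):

    def token_len(self, text):
        # Assumed helper: measure chunk length in tokens rather than characters.
        return len(self.context_tokenizer.encode(text, add_special_tokens=False))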
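Note: `retrieve_top_k` searches `self.index` with the pooled DPR question vector, but the index construction is not part of the hunks above. A minimal sketch of how the tail of `load_context_embeddings` could build it from the context `pooler_output` (the choice of `faiss.IndexFlatIP` is an assumption; DPR scores passages by inner product, though an L2 index would also run here):

    import faiss

    # Assumed continuation of load_context_embeddings: one DPR vector per chunk.
    embeddings = model_output.pooler_output.cpu().numpy().astype('float32')

    self.index = faiss.IndexFlatIP(embeddings.shape[1])  # inner-product search, matching DPR's scoring
    self.index.add(embeddings)                           # row i of the index corresponds to self.chunks[i]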
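Note: the new `extractive_query` takes independent argmaxes of the start and end logits and swaps them when they come out reversed. A common alternative, not used in this commit, is to score only valid spans (end >= start, bounded length) and pick the best pair; a sketch under that assumption:

    import torch

    def best_span(start_logits, end_logits, max_len=30):
        # scores[i, j] = start_logits[i] + end_logits[j] for a candidate answer span [i, j]
        scores = start_logits[0].unsqueeze(1) + end_logits[0].unsqueeze(0)
        idx = torch.arange(scores.size(0))
        # keep only spans with end >= start and length <= max_len
        valid = (idx.unsqueeze(0) >= idx.unsqueeze(1)) & (idx.unsqueeze(0) < idx.unsqueeze(1) + max_len)
        flat = scores.masked_fill(~valid, float('-inf')).argmax().item()
        return flat // scores.size(1), flat % scores.size(1)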