Update app.py
app.py CHANGED
@@ -1,256 +1,3 @@
-'''
-
-import os
-import re
-import json
-import torch
-import numpy as np
-import logging
-from typing import Dict, List, Tuple, Optional
-from tqdm import tqdm
-from pydantic import BaseModel
-from transformers import (
-    AutoTokenizer,
-    AutoModelForSeq2SeqLM,
-    AutoModelForQuestionAnswering,
-    pipeline,
-    LogitsProcessor,
-    LogitsProcessorList,
-    PreTrainedModel,
-    PreTrainedTokenizer
-)
-from sentence_transformers import SentenceTransformer, CrossEncoder
-from sklearn.feature_extraction.text import TfidfVectorizer
-from rank_bm25 import BM25Okapi
-import PyPDF2
-from sklearn.cluster import KMeans
-import spacy
-import subprocess
-import gradio as gr
-
-logging.basicConfig(
-    level=logging.INFO,
-    format="%(asctime)s [%(levelname)s] %(message)s"
-)
-
-class ConfidenceCalibrator(LogitsProcessor):
-    def __init__(self, calibration_factor: float = 0.9):
-        self.calibration_factor = calibration_factor
-
-    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
-        return scores / self.calibration_factor
-
-class DocumentResult(BaseModel):
-    content: str
-    confidence: float
-    source_page: int
-    supporting_evidence: List[str]
-
-class OptimalModelSelector:
-    def __init__(self):
-        self.qa_models = {
-            "deberta-v3": ("deepset/deberta-v3-large-squad2", 0.87)
-        }
-        self.summarization_models = {
-            "bart": ("facebook/bart-large-cnn", 0.85)
-        }
-        self.current_models = {}
-
-    def get_best_model(self, task_type: str) -> Tuple[PreTrainedModel, PreTrainedTokenizer, float]:
-        model_map = self.qa_models if "qa" in task_type else self.summarization_models
-        best_model_name, best_score = max(model_map.items(), key=lambda x: x[1][1])
-        if best_model_name not in self.current_models:
-            tokenizer = AutoTokenizer.from_pretrained(model_map[best_model_name][0])
-            model = (AutoModelForQuestionAnswering if "qa" in task_type
-                     else AutoModelForSeq2SeqLM).from_pretrained(model_map[best_model_name][0])
-            model = model.eval().half().to('cuda' if torch.cuda.is_available() else 'cpu')
-            self.current_models[best_model_name] = (model, tokenizer)
-        return *self.current_models[best_model_name], best_score
-
-class PDFAugmentedRetriever:
-    def __init__(self, document_texts: List[str]):
-        self.documents = [(i, text) for i, text in enumerate(document_texts)]
-        self.bm25 = BM25Okapi([text.split() for _, text in self.documents])
-        self.encoder = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-6-v2')
-        self.tfidf = TfidfVectorizer(stop_words='english').fit([text for _, text in self.documents])
-
-    def retrieve(self, query: str, top_k: int = 5) -> List[Tuple[int, str, float]]:
-        bm25_scores = self.bm25.get_scores(query.split())
-        semantic_scores = self.encoder.predict([(query, doc) for _, doc in self.documents])
-        combined_scores = 0.4 * bm25_scores + 0.6 * np.array(semantic_scores)
-        top_indices = np.argsort(combined_scores)[-top_k:][::-1]
-        return [(self.documents[i][0], self.documents[i][1], float(combined_scores[i]))
-                for i in top_indices]
-
-class DetailedExplainer:
-    def __init__(self,
-                 explanation_model: str = "google/flan-t5-large",
-                 device: int = 0):
-        try:
-            self.nlp = spacy.load("en_core_web_sm")
-        except OSError:
-            subprocess.run(["python", "-m", "spacy", "download", "en_core_web_sm"], check=True)
-            self.nlp = spacy.load("en_core_web_sm")
-        self.explainer = pipeline(
-            "text2text-generation",
-            model=explanation_model,
-            tokenizer=explanation_model,
-            device=device,
-            max_length=500,
-            max_new_tokens=800
-        )
-
-    def extract_concepts(self, text: str) -> list:
-        doc = self.nlp(text)
-        concepts = set()
-        for chunk in doc.noun_chunks:
-            if len(chunk) > 1 and not chunk.root.is_stop:
-                concepts.add(chunk.text.strip())
-        for ent in doc.ents:
-            if ent.label_ in ["PERSON", "ORG", "GPE", "NORP", "EVENT", "WORK_OF_ART"]:
-                concepts.add(ent.text.strip())
-        return list(concepts)
-
-    def explain_concept(self, concept: str, context: str, min_accuracy: float = 0.50) -> str:
-        prompt = (
-            f"The following sentence from a PDF is given \n{context}\n\n\nNow explain the concept '{concept}' mentioned above with at least {int(min_accuracy * 100)}% accuracy."
-        )
-        result = self.explainer(
-            prompt,
-            do_sample=False
-        )
-        return result[0]["generated_text"].strip()
-
-    def explain_text(self, text: str, context: str) -> dict:
-        concepts = self.extract_concepts(text)
-        explanations = {}
-        for concept in concepts:
-            explanations[concept] = self.explain_concept(concept, context)
-        return {"concepts": concepts, "explanations": explanations}
-
-class AdvancedPDFAnalyzer:
-    def __init__(self):
-        self.logger = logging.getLogger("PDFAnalyzer")
-        self.model_selector = OptimalModelSelector()
-        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
-        self.qa_model, self.qa_tokenizer, _ = self.model_selector.get_best_model("qa")
-        self.qa_model = self.qa_model.to(self.device)
-        self.summarizer = pipeline(
-            "summarization",
-            model="facebook/bart-large-cnn",
-            device=0 if torch.cuda.is_available() else -1,
-            framework="pt"
-        )
-        self.logits_processor = LogitsProcessorList([
-            ConfidenceCalibrator(calibration_factor=0.85)
-        ])
-        self.detailed_explainer = DetailedExplainer(device=0 if torch.cuda.is_available() else -1)
-
-    def extract_text_with_metadata(self, file_path: str) -> List[Dict]:
-        documents = []
-        with open(file_path, 'rb') as f:
-            reader = PyPDF2.PdfReader(f)
-            for i, page in enumerate(reader.pages):
-                text = page.extract_text()
-                if not text or not text.strip():
-                    continue
-                page_number = i + 1
-                metadata = {
-                    'source': os.path.basename(file_path),
-                    'page': page_number,
-                    'char_count': len(text),
-                    'word_count': len(text.split()),
-                }
-                documents.append({
-                    'content': self._clean_text(text),
-                    'metadata': metadata
-                })
-        if not documents:
-            raise ValueError("No extractable content found in PDF")
-        return documents
-
-    def _clean_text(self, text: str) -> str:
-        text = re.sub(r'[\x00-\x1F\x7F-\x9F]', ' ', text)
-        text = re.sub(r'\s+', ' ', text)
-        text = re.sub(r'(\w)-\s+(\w)', r'\1\2', text)
-        return text.strip()
-
-    def answer_question(self, question: str, documents: List[Dict]) -> Dict:
-        retriever = PDFAugmentedRetriever([doc['content'] for doc in documents])
-        relevant_contexts = retriever.retrieve(question, top_k=3)
-        answers = []
-        for page_idx, context, similarity_score in relevant_contexts:
-            inputs = self.qa_tokenizer(
-                question,
-                context,
-                add_special_tokens=True,
-                return_tensors="pt",
-                max_length=512,
-                truncation=True
-            )
-            inputs = {k: v.to(self.device) for k, v in inputs.items()}
-            with torch.no_grad():
-                outputs = self.qa_model(**inputs)
-            start_logits = outputs.start_logits
-            end_logits = outputs.end_logits
-            logits_processor = LogitsProcessorList([ConfidenceCalibrator()])
-            start_logits = logits_processor(inputs['input_ids'], start_logits)
-            end_logits = logits_processor(inputs['input_ids'], end_logits)
-            start_prob = torch.nn.functional.softmax(start_logits, dim=-1)
-            end_prob = torch.nn.functional.softmax(end_logits, dim=-1)
-            max_start_score, max_start_idx = torch.max(start_prob, dim=-1)
-            max_start_idx_int = max_start_idx.item()
-            max_end_score, max_end_idx = torch.max(end_prob[0, max_start_idx_int:], dim=-1)
-            max_end_idx_int = max_end_idx.item() + max_start_idx_int
-            confidence = float((max_start_score * max_end_score) * 0.9 * similarity_score)
-            answer_tokens = inputs["input_ids"][0][max_start_idx_int:max_end_idx_int + 1]
-            answer = self.qa_tokenizer.decode(answer_tokens, skip_special_tokens=True)
-            explanations_result = self.detailed_explainer.explain_text(answer, context)
-            answers.append({
-                "answer": answer,
-                "confidence": confidence,
-                "context": context,
-                "page_number": documents[page_idx]['metadata']['page'],
-                "explanations": explanations_result
-            })
-        if not answers:
-            return {"answer": "No confident answer found", "confidence": 0.0, "explanations": {}}
-        best_answer = max(answers, key=lambda x: x['confidence'])
-        if best_answer['confidence'] < 0.85:
-            best_answer['answer'] = f"[Low Confidence] {best_answer['answer']}"
-        return answers #MODUSTIFIED HERE YOU REMOVE THIS THIS LINE OF CODE IF IT CRASHES, DAT 10TH AUG, 11:49AM
-        return best_answer
-
-analyzer = AdvancedPDFAnalyzer()
-documents = analyzer.extract_text_with_metadata("example.pdf")
-
-def ask_question_gradio(question: str):
-    if not question.strip():
-        return "Please enter a valid question."
-    try:
-        result = analyzer.answer_question(question, documents)
-        answer = result['answer']
-        confidence = result['confidence']
-        explanation = "\n\n".join(
-            f"🔹 {concept}: {desc}"
-            for concept, desc in result.get("explanations", {}).get("explanations", {}).items()
-        )
-        return f"📌 **Answer**: {answer}\n\n🔒 **Confidence**: {confidence:.2f}\n\n📘 **Explanations**:\n{explanation}"
-    except Exception as e:
-        return f"❌ Error: {str(e)}"
-
-demo = gr.Interface(
-    fn=ask_question_gradio,
-    inputs=gr.Textbox(label="Ask a question about the PDF"),
-    outputs=gr.Markdown(label="Answer"),
-    title="Quandans AI - Ask Questions",
-    description="Ask a question based on the document loaded in this system."
-)
-
-demo.launch()
-'''
-
-
 import os
 import re
 import json
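
The ConfidenceCalibrator deleted above (and still constructed with calibration_factor=0.85 further down) is plain temperature scaling: dividing logits by a factor below 1 sharpens the softmax, so every reported confidence rises. A minimal standalone sketch of that effect, using only PyTorch and dummy logits:

import torch

logits = torch.tensor([2.0, 1.0, 0.5])  # dummy start/end logits
for factor in (1.0, 0.9):
    # ConfidenceCalibrator.__call__ returns scores / calibration_factor
    probs = torch.softmax(logits / factor, dim=-1)
    print(f"factor={factor}: peak probability {probs.max().item():.3f}")

The peak probability climbs as the factor shrinks, which is why confidences produced this way should be read as rescaled scores rather than calibrated probabilities.
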
@@ -325,7 +72,7 @@ class PDFAugmentedRetriever:
         self.encoder = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-6-v2')
         self.tfidf = TfidfVectorizer(stop_words='english').fit([text for _, text in self.documents])
 
-    def retrieve(self, query: str, top_k: int = 5) -> List[Tuple[int, str, float]]:
+    def retrieve(self, query: str, top_k: int = 8) -> List[Tuple[int, str, float]]: # Increased from 5 to 8
         bm25_scores = self.bm25.get_scores(query.split())
         semantic_scores = self.encoder.predict([(query, doc) for _, doc in self.documents])
         combined_scores = 0.4 * bm25_scores + 0.6 * np.array(semantic_scores)
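
Note that both before and after this change, retrieve mixes raw BM25 scores (unbounded, non-negative) with raw cross-encoder logits (which can be negative) at fixed 0.4/0.6 weights, so whichever signal has the larger scale dominates the ranking. A hedged sketch of one common remedy, min-max normalising each signal first; the dummy scores and the minmax helper are illustrative, not part of this commit:

import numpy as np

def minmax(x: np.ndarray) -> np.ndarray:
    # Rescale to [0, 1]; a constant vector maps to all zeros.
    span = x.max() - x.min()
    return (x - x.min()) / span if span else np.zeros_like(x)

bm25_scores = np.array([12.3, 0.0, 4.1])      # dummy BM25 scores
semantic_scores = np.array([-4.2, 6.8, 1.1])  # dummy cross-encoder logits
combined = 0.4 * minmax(bm25_scores) + 0.6 * minmax(semantic_scores)
print(combined.argsort()[::-1])               # document indices, best first
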
@@ -347,8 +94,8 @@ class DetailedExplainer:
             model=explanation_model,
             tokenizer=explanation_model,
             device=device,
-            max_length=500,
-            max_new_tokens=800
+            max_length=2048,
+            max_new_tokens=2000
         )
 
     def extract_concepts(self, text: str) -> list:
@@ -364,11 +111,17 @@ class DetailedExplainer:
 
     def explain_concept(self, concept: str, context: str, min_accuracy: float = 0.50) -> str:
         prompt = (
-            f"The following sentence from a PDF is given \n{context}\n\n\nNow explain the concept '{concept}' mentioned above with at least {int(min_accuracy * 100)}% accuracy."
+            f"The following sentence from a PDF is given \n{context}\n\n\n"
+            f"Now provide a detailed explanation of the concept '{concept}' mentioned above. "
+            f"Include background information, context, examples, and significance. "
+            f"Write a comprehensive explanation with at least {int(min_accuracy * 100)}% accuracy. "
+            f"Make the explanation thorough and informative, up to 500 words if needed."
         )
         result = self.explainer(
             prompt,
-            do_sample=False
+            do_sample=False,
+            max_length=2048,
+            max_new_tokens=600
         )
         return result[0]["generated_text"].strip()
 
@@ -390,7 +143,9 @@ class AdvancedPDFAnalyzer:
             "summarization",
             model="facebook/bart-large-cnn",
             device=0 if torch.cuda.is_available() else -1,
-            framework="pt"
+            framework="pt",
+            max_length=2048,
+            min_length=100
         )
         self.logits_processor = LogitsProcessorList([
             ConfidenceCalibrator(calibration_factor=0.85)
@@ -428,7 +183,7 @@ class AdvancedPDFAnalyzer:
 
     def answer_question(self, question: str, documents: List[Dict]) -> Dict:
         retriever = PDFAugmentedRetriever([doc['content'] for doc in documents])
-        relevant_contexts = retriever.retrieve(question, top_k=3)
+        relevant_contexts = retriever.retrieve(question, top_k=5) # Increased context retrieval
         answers = []
 
         for page_idx, context, similarity_score in relevant_contexts:
@@ -437,8 +192,9 @@ class AdvancedPDFAnalyzer:
                 context,
                 add_special_tokens=True,
                 return_tensors="pt",
-                max_length=512,
-                truncation=True
+                max_length=1024, # Increased from 512
+                truncation=True,
+                padding=True
             )
             inputs = {k: v.to(self.device) for k, v in inputs.items()}
 
@@ -463,6 +219,16 @@ class AdvancedPDFAnalyzer:
             answer_tokens = inputs["input_ids"][0][max_start_idx_int:max_end_idx_int + 1]
             answer = self.qa_tokenizer.decode(answer_tokens, skip_special_tokens=True)
 
+            # Enhanced answer extraction for longer responses
+            if len(answer.strip()) < 20: # If answer is too short, try extracting more context
+                # Get more surrounding context
+                extended_start = max(0, max_start_idx_int - 50)
+                extended_end = min(len(inputs["input_ids"][0]), max_end_idx_int + 150)
+                extended_tokens = inputs["input_ids"][0][extended_start:extended_end]
+                extended_answer = self.qa_tokenizer.decode(extended_tokens, skip_special_tokens=True)
+                if len(extended_answer.strip()) > len(answer.strip()):
+                    answer = extended_answer
+
             # Only generate explanations if we have a valid answer
             explanations_result = {"concepts": [], "explanations": {}}
             if answer and answer.strip():
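
The block added above widens the decoded token window whenever the span answer comes back shorter than 20 characters, taking up to 50 tokens before the start index and 150 after the end. The same clamping logic in isolation, on a plain list standing in for token ids (widen_span is an illustrative name, not from the commit):

def widen_span(input_ids, start, end, before=50, after=150):
    # Clamp the widened window to the sequence bounds, as in the hunk above.
    ext_start = max(0, start - before)
    ext_end = min(len(input_ids), end + after)
    return input_ids[ext_start:ext_end]

ids = list(range(400))                 # stand-in for a 400-token input
print(len(widen_span(ids, 180, 185)))  # 205 tokens instead of the short span
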
@@ -491,6 +257,23 @@ class AdvancedPDFAnalyzer:
         # Get the best answer based on confidence
         best_answer = max(answers, key=lambda x: x['confidence'])
 
+        # For comprehensive responses, combine information from multiple high-confidence answers
+        if len(answers) > 1:
+            high_confidence_answers = [a for a in answers if a['confidence'] > 0.2]
+            if len(high_confidence_answers) > 1:
+                # Combine explanations from multiple sources
+                combined_explanations = {}
+                all_concepts = set()
+
+                for ans in high_confidence_answers[:3]: # Use top 3 answers
+                    explanations = ans.get("explanations", {}).get("explanations", {})
+                    concepts = ans.get("explanations", {}).get("concepts", [])
+                    all_concepts.update(concepts)
+                    combined_explanations.update(explanations)
+
+                best_answer["explanations"]["explanations"] = combined_explanations
+                best_answer["explanations"]["concepts"] = list(all_concepts)
+
         # FIXED: Always return the best answer dictionary, just modify the answer text if confidence is low
         if best_answer['confidence'] < 0.3: # Lowered threshold to be more permissive
             best_answer['answer'] = f"[Low Confidence] {best_answer['answer']}"
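
One consequence of the merge added above: combined_explanations is built with dict.update, so when several answers explain the same concept, the text from the last answer processed silently replaces the earlier ones. A small demonstration with dummy answers:

answers = [
    {"explanations": {"concepts": ["BM25"],
                      "explanations": {"BM25": "ranking function (answer 1)"}}},
    {"explanations": {"concepts": ["BM25", "TF-IDF"],
                      "explanations": {"BM25": "ranking function (answer 2)",
                                       "TF-IDF": "term weighting (answer 2)"}}},
]
combined_explanations, all_concepts = {}, set()
for ans in answers[:3]:  # mirrors the top-3 loop in the hunk above
    all_concepts.update(ans["explanations"]["concepts"])
    combined_explanations.update(ans["explanations"]["explanations"])
print(sorted(all_concepts))           # ['BM25', 'TF-IDF']
print(combined_explanations["BM25"])  # "ranking function (answer 2)": last writer wins
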
@@ -562,15 +345,27 @@ else:
 
 demo = gr.Interface(
     fn=ask_question_gradio,
-    inputs=gr.Textbox(label="Ask a question about the PDF"),
-    outputs=gr.Markdown(label="Answer"),
-    title="Quandans AI - Ask Questions",
-    description="Ask a question based on the document loaded in this system.",
+    inputs=gr.Textbox(
+        label="Ask a question about the PDF",
+        placeholder="Type your question here...",
+        lines=3,
+        max_lines=5
+    ),
+    outputs=gr.Markdown(
+        label="Answer",
+        value="",
+        show_copy_button=True
+    ),
+    title="Quandans AI - Ask Questions (Up to 2000 words)",
+    description="Ask a question based on the document loaded in this system. The system can now provide comprehensive answers up to 2000 words with detailed explanations.",
     examples=[
         "What is the main topic of this document?",
-        "
-        "What are the conclusions mentioned?"
-    ]
+        "Provide a detailed summary of the key points from page 1",
+        "What are the conclusions mentioned and explain them in detail?",
+        "Give me a comprehensive overview of all the important concepts discussed"
+    ],
+    theme=gr.themes.Soft(),
+    allow_flagging="never"
 )
 
 demo.launch()