AmelC committed
Commit 3d26581 · verified · 1 parent: e7747d3

Update app.py

Files changed (1): app.py +64 -269
app.py CHANGED
@@ -1,256 +1,3 @@
-'''
-
-import os
-import re
-import json
-import torch
-import numpy as np
-import logging
-from typing import Dict, List, Tuple, Optional
-from tqdm import tqdm
-from pydantic import BaseModel
-from transformers import (
-    AutoTokenizer,
-    AutoModelForSeq2SeqLM,
-    AutoModelForQuestionAnswering,
-    pipeline,
-    LogitsProcessor,
-    LogitsProcessorList,
-    PreTrainedModel,
-    PreTrainedTokenizer
-)
-from sentence_transformers import SentenceTransformer, CrossEncoder
-from sklearn.feature_extraction.text import TfidfVectorizer
-from rank_bm25 import BM25Okapi
-import PyPDF2
-from sklearn.cluster import KMeans
-import spacy
-import subprocess
-import gradio as gr
-
-logging.basicConfig(
-    level=logging.INFO,
-    format="%(asctime)s [%(levelname)s] %(message)s"
-)
-
-class ConfidenceCalibrator(LogitsProcessor):
-    def __init__(self, calibration_factor: float = 0.9):
-        self.calibration_factor = calibration_factor
-
-    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
-        return scores / self.calibration_factor
-
-class DocumentResult(BaseModel):
-    content: str
-    confidence: float
-    source_page: int
-    supporting_evidence: List[str]
-
-class OptimalModelSelector:
-    def __init__(self):
-        self.qa_models = {
-            "deberta-v3": ("deepset/deberta-v3-large-squad2", 0.87)
-        }
-        self.summarization_models = {
-            "bart": ("facebook/bart-large-cnn", 0.85)
-        }
-        self.current_models = {}
-
-    def get_best_model(self, task_type: str) -> Tuple[PreTrainedModel, PreTrainedTokenizer, float]:
-        model_map = self.qa_models if "qa" in task_type else self.summarization_models
-        best_model_name, best_score = max(model_map.items(), key=lambda x: x[1][1])
-        if best_model_name not in self.current_models:
-            tokenizer = AutoTokenizer.from_pretrained(model_map[best_model_name][0])
-            model = (AutoModelForQuestionAnswering if "qa" in task_type
-                     else AutoModelForSeq2SeqLM).from_pretrained(model_map[best_model_name][0])
-            model = model.eval().half().to('cuda' if torch.cuda.is_available() else 'cpu')
-            self.current_models[best_model_name] = (model, tokenizer)
-        return *self.current_models[best_model_name], best_score
-
-class PDFAugmentedRetriever:
-    def __init__(self, document_texts: List[str]):
-        self.documents = [(i, text) for i, text in enumerate(document_texts)]
-        self.bm25 = BM25Okapi([text.split() for _, text in self.documents])
-        self.encoder = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-6-v2')
-        self.tfidf = TfidfVectorizer(stop_words='english').fit([text for _, text in self.documents])
-
-    def retrieve(self, query: str, top_k: int = 5) -> List[Tuple[int, str, float]]:
-        bm25_scores = self.bm25.get_scores(query.split())
-        semantic_scores = self.encoder.predict([(query, doc) for _, doc in self.documents])
-        combined_scores = 0.4 * bm25_scores + 0.6 * np.array(semantic_scores)
-        top_indices = np.argsort(combined_scores)[-top_k:][::-1]
-        return [(self.documents[i][0], self.documents[i][1], float(combined_scores[i]))
-                for i in top_indices]
-
-class DetailedExplainer:
-    def __init__(self,
-                 explanation_model: str = "google/flan-t5-large",
-                 device: int = 0):
-        try:
-            self.nlp = spacy.load("en_core_web_sm")
-        except OSError:
-            subprocess.run(["python", "-m", "spacy", "download", "en_core_web_sm"], check=True)
-            self.nlp = spacy.load("en_core_web_sm")
-        self.explainer = pipeline(
-            "text2text-generation",
-            model=explanation_model,
-            tokenizer=explanation_model,
-            device=device,
-            max_length=500,
-            max_new_tokens=800
-        )
-
-    def extract_concepts(self, text: str) -> list:
-        doc = self.nlp(text)
-        concepts = set()
-        for chunk in doc.noun_chunks:
-            if len(chunk) > 1 and not chunk.root.is_stop:
-                concepts.add(chunk.text.strip())
-        for ent in doc.ents:
-            if ent.label_ in ["PERSON", "ORG", "GPE", "NORP", "EVENT", "WORK_OF_ART"]:
-                concepts.add(ent.text.strip())
-        return list(concepts)
-
-    def explain_concept(self, concept: str, context: str, min_accuracy: float = 0.50) -> str:
-        prompt = (
-            f"The following sentence from a PDF is given \n{context}\n\n\nNow explain the concept '{concept}' mentioned above with at least {int(min_accuracy * 100)}% accuracy."
-        )
-        result = self.explainer(
-            prompt,
-            do_sample=False
-        )
-        return result[0]["generated_text"].strip()
-
-    def explain_text(self, text: str, context: str) -> dict:
-        concepts = self.extract_concepts(text)
-        explanations = {}
-        for concept in concepts:
-            explanations[concept] = self.explain_concept(concept, context)
-        return {"concepts": concepts, "explanations": explanations}
-
-class AdvancedPDFAnalyzer:
-    def __init__(self):
-        self.logger = logging.getLogger("PDFAnalyzer")
-        self.model_selector = OptimalModelSelector()
-        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
-        self.qa_model, self.qa_tokenizer, _ = self.model_selector.get_best_model("qa")
-        self.qa_model = self.qa_model.to(self.device)
-        self.summarizer = pipeline(
-            "summarization",
-            model="facebook/bart-large-cnn",
-            device=0 if torch.cuda.is_available() else -1,
-            framework="pt"
-        )
-        self.logits_processor = LogitsProcessorList([
-            ConfidenceCalibrator(calibration_factor=0.85)
-        ])
-        self.detailed_explainer = DetailedExplainer(device=0 if torch.cuda.is_available() else -1)
-
-    def extract_text_with_metadata(self, file_path: str) -> List[Dict]:
-        documents = []
-        with open(file_path, 'rb') as f:
-            reader = PyPDF2.PdfReader(f)
-            for i, page in enumerate(reader.pages):
-                text = page.extract_text()
-                if not text or not text.strip():
-                    continue
-                page_number = i + 1
-                metadata = {
-                    'source': os.path.basename(file_path),
-                    'page': page_number,
-                    'char_count': len(text),
-                    'word_count': len(text.split()),
-                }
-                documents.append({
-                    'content': self._clean_text(text),
-                    'metadata': metadata
-                })
-        if not documents:
-            raise ValueError("No extractable content found in PDF")
-        return documents
-
-    def _clean_text(self, text: str) -> str:
-        text = re.sub(r'[\x00-\x1F\x7F-\x9F]', ' ', text)
-        text = re.sub(r'\s+', ' ', text)
-        text = re.sub(r'(\w)-\s+(\w)', r'\1\2', text)
-        return text.strip()
-
-    def answer_question(self, question: str, documents: List[Dict]) -> Dict:
-        retriever = PDFAugmentedRetriever([doc['content'] for doc in documents])
-        relevant_contexts = retriever.retrieve(question, top_k=3)
-        answers = []
-        for page_idx, context, similarity_score in relevant_contexts:
-            inputs = self.qa_tokenizer(
-                question,
-                context,
-                add_special_tokens=True,
-                return_tensors="pt",
-                max_length=512,
-                truncation=True
-            )
-            inputs = {k: v.to(self.device) for k, v in inputs.items()}
-            with torch.no_grad():
-                outputs = self.qa_model(**inputs)
-            start_logits = outputs.start_logits
-            end_logits = outputs.end_logits
-            logits_processor = LogitsProcessorList([ConfidenceCalibrator()])
-            start_logits = logits_processor(inputs['input_ids'], start_logits)
-            end_logits = logits_processor(inputs['input_ids'], end_logits)
-            start_prob = torch.nn.functional.softmax(start_logits, dim=-1)
-            end_prob = torch.nn.functional.softmax(end_logits, dim=-1)
-            max_start_score, max_start_idx = torch.max(start_prob, dim=-1)
-            max_start_idx_int = max_start_idx.item()
-            max_end_score, max_end_idx = torch.max(end_prob[0, max_start_idx_int:], dim=-1)
-            max_end_idx_int = max_end_idx.item() + max_start_idx_int
-            confidence = float((max_start_score * max_end_score) * 0.9 * similarity_score)
-            answer_tokens = inputs["input_ids"][0][max_start_idx_int:max_end_idx_int + 1]
-            answer = self.qa_tokenizer.decode(answer_tokens, skip_special_tokens=True)
-            explanations_result = self.detailed_explainer.explain_text(answer, context)
-            answers.append({
-                "answer": answer,
-                "confidence": confidence,
-                "context": context,
-                "page_number": documents[page_idx]['metadata']['page'],
-                "explanations": explanations_result
-            })
-        if not answers:
-            return {"answer": "No confident answer found", "confidence": 0.0, "explanations": {}}
-        best_answer = max(answers, key=lambda x: x['confidence'])
-        if best_answer['confidence'] < 0.85:
-            best_answer['answer'] = f"[Low Confidence] {best_answer['answer']}"
-        return answers #MODUSTIFIED HERE YOU REMOVE THIS THIS LINE OF CODE IF IT CRASHES, DAT 10TH AUG, 11:49AM
-        return best_answer
-
-analyzer = AdvancedPDFAnalyzer()
-documents = analyzer.extract_text_with_metadata("example.pdf")
-
-def ask_question_gradio(question: str):
-    if not question.strip():
-        return "Please enter a valid question."
-    try:
-        result = analyzer.answer_question(question, documents)
-        answer = result['answer']
-        confidence = result['confidence']
-        explanation = "\n\n".join(
-            f"🔹 {concept}: {desc}"
-            for concept, desc in result.get("explanations", {}).get("explanations", {}).items()
-        )
-        return f"📌 **Answer**: {answer}\n\n🔒 **Confidence**: {confidence:.2f}\n\n📘 **Explanations**:\n{explanation}"
-    except Exception as e:
-        return f"❌ Error: {str(e)}"
-
-demo = gr.Interface(
-    fn=ask_question_gradio,
-    inputs=gr.Textbox(label="Ask a question about the PDF"),
-    outputs=gr.Markdown(label="Answer"),
-    title="Quandans AI - Ask Questions",
-    description="Ask a question based on the document loaded in this system."
-)
-
-demo.launch()
-'''
-
-
 import os
 import re
 import json
@@ -325,7 +72,7 @@ class PDFAugmentedRetriever:
         self.encoder = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-6-v2')
         self.tfidf = TfidfVectorizer(stop_words='english').fit([text for _, text in self.documents])
 
-    def retrieve(self, query: str, top_k: int = 5) -> List[Tuple[int, str, float]]:
+    def retrieve(self, query: str, top_k: int = 8) -> List[Tuple[int, str, float]]:  # Increased from 5 to 8
         bm25_scores = self.bm25.get_scores(query.split())
         semantic_scores = self.encoder.predict([(query, doc) for _, doc in self.documents])
         combined_scores = 0.4 * bm25_scores + 0.6 * np.array(semantic_scores)
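Note on this hunk: `retrieve` fuses BM25 and cross-encoder relevance with fixed 0.4/0.6 weights, but the two score arrays live on different scales (BM25 is unbounded and corpus-dependent, the MiniLM cross-encoder emits raw model scores), so the weighted sum can be dominated by one signal. A minimal sketch of the same fusion with per-signal min-max normalization; the corpus and query are hypothetical, and this is one common refinement rather than what the commit implements:

```python
import numpy as np
from rank_bm25 import BM25Okapi
from sentence_transformers import CrossEncoder

docs = ["neural networks approximate functions", "BM25 ranks documents by term overlap"]  # hypothetical corpus
query = "how does BM25 rank documents?"

bm25 = BM25Okapi([d.split() for d in docs])
lexical = bm25.get_scores(query.split())                            # unbounded, corpus-dependent
encoder = CrossEncoder("cross-encoder/ms-marco-MiniLM-L-6-v2")
semantic = np.asarray(encoder.predict([(query, d) for d in docs]))  # raw model scores

def minmax(x):
    # Guard against a zero range when every document scores the same.
    span = x.max() - x.min()
    return (x - x.min()) / span if span else np.zeros_like(x)

combined = 0.4 * minmax(lexical) + 0.6 * minmax(semantic)
print(combined.argsort()[::-1])  # document indices, best match first
```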
@@ -347,8 +94,8 @@ class DetailedExplainer:
             model=explanation_model,
             tokenizer=explanation_model,
             device=device,
-            max_length=500,
-            max_new_tokens=800
+            max_length=2048,
+            max_new_tokens=2000
         )
 
     def extract_concepts(self, text: str) -> list:
@@ -364,11 +111,17 @@ class DetailedExplainer:
 
     def explain_concept(self, concept: str, context: str, min_accuracy: float = 0.50) -> str:
         prompt = (
-            f"The following sentence from a PDF is given \n{context}\n\n\nNow explain the concept '{concept}' mentioned above with at least {int(min_accuracy * 100)}% accuracy."
+            f"The following sentence from a PDF is given \n{context}\n\n\n"
+            f"Now provide a detailed explanation of the concept '{concept}' mentioned above. "
+            f"Include background information, context, examples, and significance. "
+            f"Write a comprehensive explanation with at least {int(min_accuracy * 100)}% accuracy. "
+            f"Make the explanation thorough and informative, up to 500 words if needed."
         )
         result = self.explainer(
             prompt,
-            do_sample=False
+            do_sample=False,
+            max_length=2048,
+            max_new_tokens=600
         )
         return result[0]["generated_text"].strip()
 
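One caveat on the generation knobs touched here (and in the pipeline constructor above): in recent `transformers` releases, supplying both `max_length` and `max_new_tokens` typically triggers a warning and `max_new_tokens` takes precedence, so the `max_length=2048` values are likely inert. A hedged sketch of the equivalent call keeping only the effective limit (hypothetical prompt):

```python
from transformers import pipeline

explainer = pipeline("text2text-generation", model="google/flan-t5-large")

# max_new_tokens alone bounds the generated continuation; the encoder
# still truncates prompts that exceed the model's own input limit.
result = explainer(
    "Explain the concept 'attention' mentioned in the text above.",  # hypothetical
    do_sample=False,
    max_new_tokens=600,
)
print(result[0]["generated_text"].strip())
```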
@@ -390,7 +143,9 @@ class AdvancedPDFAnalyzer:
             "summarization",
             model="facebook/bart-large-cnn",
             device=0 if torch.cuda.is_available() else -1,
-            framework="pt"
+            framework="pt",
+            max_length=2048,
+            min_length=100
         )
         self.logits_processor = LogitsProcessorList([
             ConfidenceCalibrator(calibration_factor=0.85)
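A related check for the summarizer settings: `facebook/bart-large-cnn` has a 1024-token position limit, so a generation `max_length` of 2048 exceeds what the decoder can actually produce; whether the extra headroom is silently ignored or raises depends on the `transformers` version. The limit can be confirmed from the model config:

```python
from transformers import AutoConfig

# Sanity-check the positional limit before promising 2048-token summaries.
cfg = AutoConfig.from_pretrained("facebook/bart-large-cnn")
print(cfg.max_position_embeddings)  # 1024 for this checkpoint
```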
@@ -428,7 +183,7 @@
 
     def answer_question(self, question: str, documents: List[Dict]) -> Dict:
         retriever = PDFAugmentedRetriever([doc['content'] for doc in documents])
-        relevant_contexts = retriever.retrieve(question, top_k=3)
+        relevant_contexts = retriever.retrieve(question, top_k=5)  # Increased context retrieval
         answers = []
 
         for page_idx, context, similarity_score in relevant_contexts:
@@ -437,8 +192,9 @@
                 context,
                 add_special_tokens=True,
                 return_tensors="pt",
-                max_length=512,
-                truncation=True
+                max_length=1024,  # Increased from 512
+                truncation=True,
+                padding=True
             )
             inputs = {k: v.to(self.device) for k, v in inputs.items()}
 
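Worth noting for this hunk: with a single `(question, context)` pair, `padding=True` is a no-op, and `truncation=True` still drops whatever falls past 1024 tokens, so the tail of a long page never gets scored. The usual alternative in Hugging Face QA preprocessing is overlapping windows via `return_overflowing_tokens`; a minimal sketch under hypothetical inputs, assuming a fast tokenizer:

```python
from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("deepset/deberta-v3-large-squad2")
question = "What does the document conclude?"  # hypothetical
context = "some very long page text " * 500    # hypothetical long page

# Truncate only the context, in overlapping 512-token windows; the
# 128-token stride keeps answers spanning a window boundary recoverable.
features = tok(
    question,
    context,
    truncation="only_second",
    max_length=512,
    stride=128,
    return_overflowing_tokens=True,
    padding=True,
    return_tensors="pt",
)
print(features["input_ids"].shape)  # (num_windows, 512); run the QA model on each
```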
@@ -463,6 +219,16 @@
             answer_tokens = inputs["input_ids"][0][max_start_idx_int:max_end_idx_int + 1]
             answer = self.qa_tokenizer.decode(answer_tokens, skip_special_tokens=True)
 
+            # Enhanced answer extraction for longer responses
+            if len(answer.strip()) < 20:  # If answer is too short, try extracting more context
+                # Get more surrounding context
+                extended_start = max(0, max_start_idx_int - 50)
+                extended_end = min(len(inputs["input_ids"][0]), max_end_idx_int + 150)
+                extended_tokens = inputs["input_ids"][0][extended_start:extended_end]
+                extended_answer = self.qa_tokenizer.decode(extended_tokens, skip_special_tokens=True)
+                if len(extended_answer.strip()) > len(answer.strip()):
+                    answer = extended_answer
+
             # Only generate explanations if we have a valid answer
             explanations_result = {"concepts": [], "explanations": {}}
             if answer and answer.strip():
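One caveat on the widened decode added above: a fixed -50/+150 token window can run out of the context segment into the question and special tokens. When a fast tokenizer is in use, the `BatchEncoding` (here called `encoding` for illustration, captured before its tensors are moved to the device) exposes `sequence_ids()`, which marks context positions with sequence id 1 and lets the window be clipped first; a hedged sketch reusing the hunk's `max_start_idx_int`/`max_end_idx_int`:

```python
# Clip the extended token window to the context segment (sequence id 1)
# before decoding, so question text cannot leak into the answer.
seq_ids = encoding.sequence_ids(0)              # None/0/1 per token position
ctx = [i for i, s in enumerate(seq_ids) if s == 1]
lo, hi = ctx[0], ctx[-1]
start = max(lo, max_start_idx_int - 50)
end = min(hi, max_end_idx_int + 150)
extended_answer = self.qa_tokenizer.decode(
    encoding["input_ids"][0][start:end + 1], skip_special_tokens=True
)
```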
@@ -491,6 +257,23 @@
         # Get the best answer based on confidence
         best_answer = max(answers, key=lambda x: x['confidence'])
 
+        # For comprehensive responses, combine information from multiple high-confidence answers
+        if len(answers) > 1:
+            high_confidence_answers = [a for a in answers if a['confidence'] > 0.2]
+            if len(high_confidence_answers) > 1:
+                # Combine explanations from multiple sources
+                combined_explanations = {}
+                all_concepts = set()
+
+                for ans in high_confidence_answers[:3]:  # Use top 3 answers
+                    explanations = ans.get("explanations", {}).get("explanations", {})
+                    concepts = ans.get("explanations", {}).get("concepts", [])
+                    all_concepts.update(concepts)
+                    combined_explanations.update(explanations)
+
+                best_answer["explanations"]["explanations"] = combined_explanations
+                best_answer["explanations"]["concepts"] = list(all_concepts)
+
         # FIXED: Always return the best answer dictionary, just modify the answer text if confidence is low
         if best_answer['confidence'] < 0.3:  # Lowered threshold to be more permissive
             best_answer['answer'] = f"[Low Confidence] {best_answer['answer']}"
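Small observation on the merge logic above: `answers` is built in retrieval order, so `high_confidence_answers[:3]` takes the first three retrieved hits, not the three most confident ones the comment promises. Sorting first would make the slice match the intent; a one-line sketch:

```python
# Sort descending by confidence so "[:3]" really selects the top 3 answers.
high_confidence_answers.sort(key=lambda a: a["confidence"], reverse=True)
```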
@@ -562,15 +345,27 @@ else:
 
     demo = gr.Interface(
         fn=ask_question_gradio,
-        inputs=gr.Textbox(label="Ask a question about the PDF", placeholder="Type your question here..."),
-        outputs=gr.Markdown(label="Answer"),
-        title="Quandans AI - Ask Questions",
-        description="Ask a question based on the document loaded in this system.",
+        inputs=gr.Textbox(
+            label="Ask a question about the PDF",
+            placeholder="Type your question here...",
+            lines=3,
+            max_lines=5
+        ),
+        outputs=gr.Markdown(
+            label="Answer",
+            value="",
+            show_copy_button=True
+        ),
+        title="Quandans AI - Ask Questions (Up to 2000 words)",
+        description="Ask a question based on the document loaded in this system. The system can now provide comprehensive answers up to 2000 words with detailed explanations.",
         examples=[
             "What is the main topic of this document?",
-            "Summarize the key points from page 1",
-            "What are the conclusions mentioned?"
-        ]
+            "Provide a detailed summary of the key points from page 1",
+            "What are the conclusions mentioned and explain them in detail?",
+            "Give me a comprehensive overview of all the important concepts discussed"
+        ],
+        theme=gr.themes.Soft(),
+        allow_flagging="never"
     )
 
     demo.launch()