AmelC commited on
Commit
e7747d3
Β·
verified Β·
1 Parent(s): 58c3da6

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +328 -0
app.py CHANGED
@@ -1,3 +1,5 @@
 
 
1
  import os
2
  import re
3
  import json
@@ -246,3 +248,329 @@ demo = gr.Interface(
246
  )
247
 
248
  demo.launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ '''
2
+
3
  import os
4
  import re
5
  import json
 
248
  )
249
 
250
  demo.launch()
251
+ '''
252
+
253
+
254
+ import os
255
+ import re
256
+ import json
257
+ import torch
258
+ import numpy as np
259
+ import logging
260
+ from typing import Dict, List, Tuple, Optional
261
+ from tqdm import tqdm
262
+ from pydantic import BaseModel
263
+ from transformers import (
264
+ AutoTokenizer,
265
+ AutoModelForSeq2SeqLM,
266
+ AutoModelForQuestionAnswering,
267
+ pipeline,
268
+ LogitsProcessor,
269
+ LogitsProcessorList,
270
+ PreTrainedModel,
271
+ PreTrainedTokenizer
272
+ )
273
+ from sentence_transformers import SentenceTransformer, CrossEncoder
274
+ from sklearn.feature_extraction.text import TfidfVectorizer
275
+ from rank_bm25 import BM25Okapi
276
+ import PyPDF2
277
+ from sklearn.cluster import KMeans
278
+ import spacy
279
+ import subprocess
280
+ import gradio as gr
281
+
282
+ logging.basicConfig(
283
+ level=logging.INFO,
284
+ format="%(asctime)s [%(levelname)s] %(message)s"
285
+ )
286
+
287
+ class ConfidenceCalibrator(LogitsProcessor):
288
+ def __init__(self, calibration_factor: float = 0.9):
289
+ self.calibration_factor = calibration_factor
290
+
291
+ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
292
+ return scores / self.calibration_factor
293
+
294
+ class DocumentResult(BaseModel):
295
+ content: str
296
+ confidence: float
297
+ source_page: int
298
+ supporting_evidence: List[str]
299
+
300
+ class OptimalModelSelector:
301
+ def __init__(self):
302
+ self.qa_models = {
303
+ "deberta-v3": ("deepset/deberta-v3-large-squad2", 0.87)
304
+ }
305
+ self.summarization_models = {
306
+ "bart": ("facebook/bart-large-cnn", 0.85)
307
+ }
308
+ self.current_models = {}
309
+
310
+ def get_best_model(self, task_type: str) -> Tuple[PreTrainedModel, PreTrainedTokenizer, float]:
311
+ model_map = self.qa_models if "qa" in task_type else self.summarization_models
312
+ best_model_name, best_score = max(model_map.items(), key=lambda x: x[1][1])
313
+ if best_model_name not in self.current_models:
314
+ tokenizer = AutoTokenizer.from_pretrained(model_map[best_model_name][0])
315
+ model = (AutoModelForQuestionAnswering if "qa" in task_type
316
+ else AutoModelForSeq2SeqLM).from_pretrained(model_map[best_model_name][0])
317
+ model = model.eval().half().to('cuda' if torch.cuda.is_available() else 'cpu')
318
+ self.current_models[best_model_name] = (model, tokenizer)
319
+ return *self.current_models[best_model_name], best_score
320
+
321
+ class PDFAugmentedRetriever:
322
+ def __init__(self, document_texts: List[str]):
323
+ self.documents = [(i, text) for i, text in enumerate(document_texts)]
324
+ self.bm25 = BM25Okapi([text.split() for _, text in self.documents])
325
+ self.encoder = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-6-v2')
326
+ self.tfidf = TfidfVectorizer(stop_words='english').fit([text for _, text in self.documents])
327
+
328
+ def retrieve(self, query: str, top_k: int = 5) -> List[Tuple[int, str, float]]:
329
+ bm25_scores = self.bm25.get_scores(query.split())
330
+ semantic_scores = self.encoder.predict([(query, doc) for _, doc in self.documents])
331
+ combined_scores = 0.4 * bm25_scores + 0.6 * np.array(semantic_scores)
332
+ top_indices = np.argsort(combined_scores)[-top_k:][::-1]
333
+ return [(self.documents[i][0], self.documents[i][1], float(combined_scores[i]))
334
+ for i in top_indices]
335
+
336
+ class DetailedExplainer:
337
+ def __init__(self,
338
+ explanation_model: str = "google/flan-t5-large",
339
+ device: int = 0):
340
+ try:
341
+ self.nlp = spacy.load("en_core_web_sm")
342
+ except OSError:
343
+ subprocess.run(["python", "-m", "spacy", "download", "en_core_web_sm"], check=True)
344
+ self.nlp = spacy.load("en_core_web_sm")
345
+ self.explainer = pipeline(
346
+ "text2text-generation",
347
+ model=explanation_model,
348
+ tokenizer=explanation_model,
349
+ device=device,
350
+ max_length=500,
351
+ max_new_tokens=800
352
+ )
353
+
354
+ def extract_concepts(self, text: str) -> list:
355
+ doc = self.nlp(text)
356
+ concepts = set()
357
+ for chunk in doc.noun_chunks:
358
+ if len(chunk) > 1 and not chunk.root.is_stop:
359
+ concepts.add(chunk.text.strip())
360
+ for ent in doc.ents:
361
+ if ent.label_ in ["PERSON", "ORG", "GPE", "NORP", "EVENT", "WORK_OF_ART"]:
362
+ concepts.add(ent.text.strip())
363
+ return list(concepts)
364
+
365
+ def explain_concept(self, concept: str, context: str, min_accuracy: float = 0.50) -> str:
366
+ prompt = (
367
+ f"The following sentence from a PDF is given \n{context}\n\n\nNow explain the concept '{concept}' mentioned above with at least {int(min_accuracy * 100)}% accuracy."
368
+ )
369
+ result = self.explainer(
370
+ prompt,
371
+ do_sample=False
372
+ )
373
+ return result[0]["generated_text"].strip()
374
+
375
+ def explain_text(self, text: str, context: str) -> dict:
376
+ concepts = self.extract_concepts(text)
377
+ explanations = {}
378
+ for concept in concepts:
379
+ explanations[concept] = self.explain_concept(concept, context)
380
+ return {"concepts": concepts, "explanations": explanations}
381
+
382
+ class AdvancedPDFAnalyzer:
383
+ def __init__(self):
384
+ self.logger = logging.getLogger("PDFAnalyzer")
385
+ self.model_selector = OptimalModelSelector()
386
+ self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
387
+ self.qa_model, self.qa_tokenizer, _ = self.model_selector.get_best_model("qa")
388
+ self.qa_model = self.qa_model.to(self.device)
389
+ self.summarizer = pipeline(
390
+ "summarization",
391
+ model="facebook/bart-large-cnn",
392
+ device=0 if torch.cuda.is_available() else -1,
393
+ framework="pt"
394
+ )
395
+ self.logits_processor = LogitsProcessorList([
396
+ ConfidenceCalibrator(calibration_factor=0.85)
397
+ ])
398
+ self.detailed_explainer = DetailedExplainer(device=0 if torch.cuda.is_available() else -1)
399
+
400
+ def extract_text_with_metadata(self, file_path: str) -> List[Dict]:
401
+ documents = []
402
+ with open(file_path, 'rb') as f:
403
+ reader = PyPDF2.PdfReader(f)
404
+ for i, page in enumerate(reader.pages):
405
+ text = page.extract_text()
406
+ if not text or not text.strip():
407
+ continue
408
+ page_number = i + 1
409
+ metadata = {
410
+ 'source': os.path.basename(file_path),
411
+ 'page': page_number,
412
+ 'char_count': len(text),
413
+ 'word_count': len(text.split()),
414
+ }
415
+ documents.append({
416
+ 'content': self._clean_text(text),
417
+ 'metadata': metadata
418
+ })
419
+ if not documents:
420
+ raise ValueError("No extractable content found in PDF")
421
+ return documents
422
+
423
+ def _clean_text(self, text: str) -> str:
424
+ text = re.sub(r'[\x00-\x1F\x7F-\x9F]', ' ', text)
425
+ text = re.sub(r'\s+', ' ', text)
426
+ text = re.sub(r'(\w)-\s+(\w)', r'\1\2', text)
427
+ return text.strip()
428
+
429
+ def answer_question(self, question: str, documents: List[Dict]) -> Dict:
430
+ retriever = PDFAugmentedRetriever([doc['content'] for doc in documents])
431
+ relevant_contexts = retriever.retrieve(question, top_k=3)
432
+ answers = []
433
+
434
+ for page_idx, context, similarity_score in relevant_contexts:
435
+ inputs = self.qa_tokenizer(
436
+ question,
437
+ context,
438
+ add_special_tokens=True,
439
+ return_tensors="pt",
440
+ max_length=512,
441
+ truncation=True
442
+ )
443
+ inputs = {k: v.to(self.device) for k, v in inputs.items()}
444
+
445
+ with torch.no_grad():
446
+ outputs = self.qa_model(**inputs)
447
+ start_logits = outputs.start_logits
448
+ end_logits = outputs.end_logits
449
+
450
+ logits_processor = LogitsProcessorList([ConfidenceCalibrator()])
451
+ start_logits = logits_processor(inputs['input_ids'], start_logits)
452
+ end_logits = logits_processor(inputs['input_ids'], end_logits)
453
+
454
+ start_prob = torch.nn.functional.softmax(start_logits, dim=-1)
455
+ end_prob = torch.nn.functional.softmax(end_logits, dim=-1)
456
+
457
+ max_start_score, max_start_idx = torch.max(start_prob, dim=-1)
458
+ max_start_idx_int = max_start_idx.item()
459
+ max_end_score, max_end_idx = torch.max(end_prob[0, max_start_idx_int:], dim=-1)
460
+ max_end_idx_int = max_end_idx.item() + max_start_idx_int
461
+
462
+ confidence = float((max_start_score * max_end_score) * 0.9 * similarity_score)
463
+ answer_tokens = inputs["input_ids"][0][max_start_idx_int:max_end_idx_int + 1]
464
+ answer = self.qa_tokenizer.decode(answer_tokens, skip_special_tokens=True)
465
+
466
+ # Only generate explanations if we have a valid answer
467
+ explanations_result = {"concepts": [], "explanations": {}}
468
+ if answer and answer.strip():
469
+ try:
470
+ explanations_result = self.detailed_explainer.explain_text(answer, context)
471
+ except Exception as e:
472
+ self.logger.warning(f"Failed to generate explanations: {e}")
473
+
474
+ answers.append({
475
+ "answer": answer,
476
+ "confidence": confidence,
477
+ "context": context,
478
+ "page_number": documents[page_idx]['metadata']['page'],
479
+ "explanations": explanations_result
480
+ })
481
+
482
+ if not answers:
483
+ return {
484
+ "answer": "No confident answer found",
485
+ "confidence": 0.0,
486
+ "explanations": {"concepts": [], "explanations": {}},
487
+ "page_number": 0,
488
+ "context": ""
489
+ }
490
+
491
+ # Get the best answer based on confidence
492
+ best_answer = max(answers, key=lambda x: x['confidence'])
493
+
494
+ # FIXED: Always return the best answer dictionary, just modify the answer text if confidence is low
495
+ if best_answer['confidence'] < 0.3: # Lowered threshold to be more permissive
496
+ best_answer['answer'] = f"[Low Confidence] {best_answer['answer']}"
497
+
498
+ return best_answer
499
+
500
+ # Initialize analyzer (make sure to update the PDF path)
501
+ analyzer = AdvancedPDFAnalyzer()
502
+
503
+ # Global variable to store documents
504
+ documents = []
505
+
506
+ def load_pdf(file_path: str):
507
+ """Load PDF and extract documents"""
508
+ global documents
509
+ try:
510
+ documents = analyzer.extract_text_with_metadata(file_path)
511
+ return f"Successfully loaded PDF with {len(documents)} pages."
512
+ except Exception as e:
513
+ return f"Error loading PDF: {str(e)}"
514
+
515
+ def ask_question_gradio(question: str):
516
+ if not question.strip():
517
+ return "Please enter a valid question."
518
+
519
+ if not documents:
520
+ return "❌ No PDF loaded. Please load a PDF first."
521
+
522
+ try:
523
+ result = analyzer.answer_question(question, documents)
524
+
525
+ # Ensure we have the expected structure
526
+ answer = result.get('answer', 'No answer found')
527
+ confidence = result.get('confidence', 0.0)
528
+ page_number = result.get('page_number', 0)
529
+ explanations = result.get("explanations", {}).get("explanations", {})
530
+
531
+ # Format explanations
532
+ explanation_text = ""
533
+ if explanations:
534
+ explanation_text = "\n\n".join(
535
+ f"πŸ”Ή **{concept}**: {desc}"
536
+ for concept, desc in explanations.items()
537
+ if desc and desc.strip()
538
+ )
539
+
540
+ # Build response
541
+ response_parts = [
542
+ f"πŸ“Œ **Answer**: {answer}",
543
+ f"πŸ”’ **Confidence**: {confidence:.2f}",
544
+ f"πŸ“„ **Page**: {page_number}"
545
+ ]
546
+
547
+ if explanation_text:
548
+ response_parts.append(f"πŸ“˜ **Explanations**:\n{explanation_text}")
549
+
550
+ return "\n\n".join(response_parts)
551
+
552
+ except Exception as e:
553
+ return f"❌ Error: {str(e)}"
554
+
555
+ # Load your PDF here - update the path to your actual PDF file
556
+ pdf_path = "example.pdf"
557
+ if os.path.exists(pdf_path):
558
+ load_result = load_pdf(pdf_path)
559
+ print(load_result)
560
+ else:
561
+ print(f"PDF file '{pdf_path}' not found. Please update the path.")
562
+
563
+ demo = gr.Interface(
564
+ fn=ask_question_gradio,
565
+ inputs=gr.Textbox(label="Ask a question about the PDF", placeholder="Type your question here..."),
566
+ outputs=gr.Markdown(label="Answer"),
567
+ title="Quandans AI - Ask Questions",
568
+ description="Ask a question based on the document loaded in this system.",
569
+ examples=[
570
+ "What is the main topic of this document?",
571
+ "Summarize the key points from page 1",
572
+ "What are the conclusions mentioned?"
573
+ ]
574
+ )
575
+
576
+ demo.launch()