AmelC commited on
Commit
fde38c3
Β·
verified Β·
1 Parent(s): 238741c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +255 -0
app.py CHANGED
@@ -1,3 +1,5 @@
 
 
1
  import gradio as gr
2
  import os
3
  import re
@@ -521,4 +523,257 @@ demo = gr.Interface(
521
  description="Enter a question based on the loaded PDF document. The system will provide an answer with confidence and concept explanations."
522
  )
523
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
524
  demo.launch()
 
1
+ '''
2
+
3
  import gradio as gr
4
  import os
5
  import re
 
523
  description="Enter a question based on the loaded PDF document. The system will provide an answer with confidence and concept explanations."
524
  )
525
 
526
+ demo.launch()
527
+
528
+ '''
529
+
530
+ import os
531
+ import re
532
+ import json
533
+ import torch
534
+ import numpy as np
535
+ import logging
536
+ from typing import Dict, List, Tuple, Optional
537
+ from tqdm import tqdm
538
+ from pydantic import BaseModel
539
+ from transformers import (
540
+ AutoTokenizer,
541
+ AutoModelForSeq2SeqLM,
542
+ AutoModelForQuestionAnswering,
543
+ pipeline,
544
+ LogitsProcessor,
545
+ LogitsProcessorList,
546
+ PreTrainedModel,
547
+ PreTrainedTokenizer
548
+ )
549
+ from sentence_transformers import SentenceTransformer, CrossEncoder
550
+ from sklearn.feature_extraction.text import TfidfVectorizer
551
+ from rank_bm25 import BM25Okapi
552
+ import PyPDF2
553
+ from sklearn.cluster import KMeans
554
+ import spacy
555
+ import subprocess
556
+ import gradio as gr
557
+
558
+ logging.basicConfig(
559
+ level=logging.INFO,
560
+ format="%(asctime)s [%(levelname)s] %(message)s"
561
+ )
562
+
563
+ class ConfidenceCalibrator(LogitsProcessor):
564
+ def __init__(self, calibration_factor: float = 0.9):
565
+ self.calibration_factor = calibration_factor
566
+
567
+ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
568
+ return scores / self.calibration_factor
569
+
570
+ class DocumentResult(BaseModel):
571
+ content: str
572
+ confidence: float
573
+ source_page: int
574
+ supporting_evidence: List[str]
575
+
576
+ class OptimalModelSelector:
577
+ def __init__(self):
578
+ self.qa_models = {
579
+ "deberta-v3": ("deepset/deberta-v3-large-squad2", 0.87)
580
+ }
581
+ self.summarization_models = {
582
+ "bart": ("facebook/bart-large-cnn", 0.85)
583
+ }
584
+ self.current_models = {}
585
+
586
+ def get_best_model(self, task_type: str) -> Tuple[PreTrainedModel, PreTrainedTokenizer, float]:
587
+ model_map = self.qa_models if "qa" in task_type else self.summarization_models
588
+ best_model_name, best_score = max(model_map.items(), key=lambda x: x[1][1])
589
+ if best_model_name not in self.current_models:
590
+ tokenizer = AutoTokenizer.from_pretrained(model_map[best_model_name][0])
591
+ model = (AutoModelForQuestionAnswering if "qa" in task_type
592
+ else AutoModelForSeq2SeqLM).from_pretrained(model_map[best_model_name][0])
593
+ model = model.eval().half().to('cuda' if torch.cuda.is_available() else 'cpu')
594
+ self.current_models[best_model_name] = (model, tokenizer)
595
+ return *self.current_models[best_model_name], best_score
596
+
597
+ class PDFAugmentedRetriever:
598
+ def __init__(self, document_texts: List[str]):
599
+ self.documents = [(i, text) for i, text in enumerate(document_texts)]
600
+ self.bm25 = BM25Okapi([text.split() for _, text in self.documents])
601
+ self.encoder = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-6-v2')
602
+ self.tfidf = TfidfVectorizer(stop_words='english').fit([text for _, text in self.documents])
603
+
604
+ def retrieve(self, query: str, top_k: int = 5) -> List[Tuple[int, str, float]]:
605
+ bm25_scores = self.bm25.get_scores(query.split())
606
+ semantic_scores = self.encoder.predict([(query, doc) for _, doc in self.documents])
607
+ combined_scores = 0.4 * bm25_scores + 0.6 * np.array(semantic_scores)
608
+ top_indices = np.argsort(combined_scores)[-top_k:][::-1]
609
+ return [(self.documents[i][0], self.documents[i][1], float(combined_scores[i]))
610
+ for i in top_indices]
611
+
612
+ class DetailedExplainer:
613
+ def __init__(self,
614
+ explanation_model: str = "google/flan-t5-large",
615
+ device: int = 0):
616
+ try:
617
+ self.nlp = spacy.load("en_core_web_sm")
618
+ except OSError:
619
+ subprocess.run(["python", "-m", "spacy", "download", "en_core_web_sm"], check=True)
620
+ self.nlp = spacy.load("en_core_web_sm")
621
+ self.explainer = pipeline(
622
+ "text2text-generation",
623
+ model=explanation_model,
624
+ tokenizer=explanation_model,
625
+ device=device
626
+ )
627
+
628
+ def extract_concepts(self, text: str) -> list:
629
+ doc = self.nlp(text)
630
+ concepts = set()
631
+ for chunk in doc.noun_chunks:
632
+ if len(chunk) > 1 and not chunk.root.is_stop:
633
+ concepts.add(chunk.text.strip())
634
+ for ent in doc.ents:
635
+ if ent.label_ in ["PERSON", "ORG", "GPE", "NORP", "EVENT", "WORK_OF_ART"]:
636
+ concepts.add(ent.text.strip())
637
+ return list(concepts)
638
+
639
+ def explain_concept(self, concept: str, context: str, min_accuracy: float = 0.50) -> str:
640
+ prompt = (
641
+ f"Explain the concept '{concept}' in depth using the following context. "
642
+ f"Aim for at least {int(min_accuracy * 100)}% accuracy."
643
+ f"\nContext:\n{context}\n"
644
+ )
645
+ result = self.explainer(
646
+ prompt,
647
+ max_length=200,
648
+ min_length=80,
649
+ do_sample=False
650
+ )
651
+ return result[0]["generated_text"].strip()
652
+
653
+ def explain_text(self, text: str, context: str) -> dict:
654
+ concepts = self.extract_concepts(text)
655
+ explanations = {}
656
+ for concept in concepts:
657
+ explanations[concept] = self.explain_concept(concept, context)
658
+ return {"concepts": concepts, "explanations": explanations}
659
+
660
+ class AdvancedPDFAnalyzer:
661
+ def __init__(self):
662
+ self.logger = logging.getLogger("PDFAnalyzer")
663
+ self.model_selector = OptimalModelSelector()
664
+ self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
665
+ self.qa_model, self.qa_tokenizer, _ = self.model_selector.get_best_model("qa")
666
+ self.qa_model = self.qa_model.to(self.device)
667
+ self.summarizer = pipeline(
668
+ "summarization",
669
+ model="facebook/bart-large-cnn",
670
+ device=0 if torch.cuda.is_available() else -1,
671
+ framework="pt"
672
+ )
673
+ self.logits_processor = LogitsProcessorList([
674
+ ConfidenceCalibrator(calibration_factor=0.85)
675
+ ])
676
+ self.detailed_explainer = DetailedExplainer(device=0 if torch.cuda.is_available() else -1)
677
+
678
+ def extract_text_with_metadata(self, file_path: str) -> List[Dict]:
679
+ documents = []
680
+ with open(file_path, 'rb') as f:
681
+ reader = PyPDF2.PdfReader(f)
682
+ for i, page in enumerate(reader.pages):
683
+ text = page.extract_text()
684
+ if not text or not text.strip():
685
+ continue
686
+ page_number = i + 1
687
+ metadata = {
688
+ 'source': os.path.basename(file_path),
689
+ 'page': page_number,
690
+ 'char_count': len(text),
691
+ 'word_count': len(text.split()),
692
+ }
693
+ documents.append({
694
+ 'content': self._clean_text(text),
695
+ 'metadata': metadata
696
+ })
697
+ if not documents:
698
+ raise ValueError("No extractable content found in PDF")
699
+ return documents
700
+
701
+ def _clean_text(self, text: str) -> str:
702
+ text = re.sub(r'[\x00-\x1F\x7F-\x9F]', ' ', text)
703
+ text = re.sub(r'\s+', ' ', text)
704
+ text = re.sub(r'(\w)-\s+(\w)', r'\1\2', text)
705
+ return text.strip()
706
+
707
+ def answer_question(self, question: str, documents: List[Dict]) -> Dict:
708
+ retriever = PDFAugmentedRetriever([doc['content'] for doc in documents])
709
+ relevant_contexts = retriever.retrieve(question, top_k=3)
710
+ answers = []
711
+ for page_idx, context, similarity_score in relevant_contexts:
712
+ inputs = self.qa_tokenizer(
713
+ question,
714
+ context,
715
+ add_special_tokens=True,
716
+ return_tensors="pt",
717
+ max_length=512,
718
+ truncation="only_second"
719
+ )
720
+ inputs = {k: v.to(self.device) for k, v in inputs.items()}
721
+ with torch.no_grad():
722
+ outputs = self.qa_model(**inputs)
723
+ start_logits = outputs.start_logits
724
+ end_logits = outputs.end_logits
725
+ logits_processor = LogitsProcessorList([ConfidenceCalibrator()])
726
+ start_logits = logits_processor(inputs['input_ids'], start_logits)
727
+ end_logits = logits_processor(inputs['input_ids'], end_logits)
728
+ start_prob = torch.nn.functional.softmax(start_logits, dim=-1)
729
+ end_prob = torch.nn.functional.softmax(end_logits, dim=-1)
730
+ max_start_score, max_start_idx = torch.max(start_prob, dim=-1)
731
+ max_start_idx_int = max_start_idx.item()
732
+ max_end_score, max_end_idx = torch.max(end_prob[0, max_start_idx_int:], dim=-1)
733
+ max_end_idx_int = max_end_idx.item() + max_start_idx_int
734
+ confidence = float((max_start_score * max_end_score) * 0.9 * similarity_score)
735
+ answer_tokens = inputs["input_ids"][0][max_start_idx_int:max_end_idx_int + 1]
736
+ answer = self.qa_tokenizer.decode(answer_tokens, skip_special_tokens=True)
737
+ explanations_result = self.detailed_explainer.explain_text(answer, context)
738
+ answers.append({
739
+ "answer": answer,
740
+ "confidence": confidence,
741
+ "context": context,
742
+ "page_number": documents[page_idx]['metadata']['page'],
743
+ "explanations": explanations_result
744
+ })
745
+ if not answers:
746
+ return {"answer": "No confident answer found", "confidence": 0.0, "explanations": {}}
747
+ best_answer = max(answers, key=lambda x: x['confidence'])
748
+ if best_answer['confidence'] < 0.85:
749
+ best_answer['answer'] = f"[Low Confidence] {best_answer['answer']}"
750
+ return best_answer
751
+
752
+ # Instantiate analyzer once
753
+ analyzer = AdvancedPDFAnalyzer()
754
+ documents = analyzer.extract_text_with_metadata("example.pdf")
755
+
756
+ def ask_question_gradio(question: str):
757
+ if not question.strip():
758
+ return "Please enter a valid question."
759
+ try:
760
+ result = analyzer.answer_question(question, documents)
761
+ answer = result['answer']
762
+ confidence = result['confidence']
763
+ explanation = "\n\n".join(
764
+ f"πŸ”Ή {concept}: {desc}"
765
+ for concept, desc in result.get("explanations", {}).get("explanations", {}).items()
766
+ )
767
+ return f"πŸ“Œ **Answer**: {answer}\n\nπŸ”’ **Confidence**: {confidence:.2f}\n\nπŸ“˜ **Explanations**:\n{explanation}"
768
+ except Exception as e:
769
+ return f"❌ Error: {str(e)}"
770
+
771
+ demo = gr.Interface(
772
+ fn=ask_question_gradio,
773
+ inputs=gr.Textbox(label="Ask a question about the PDF"),
774
+ outputs=gr.Markdown(label="Answer"),
775
+ title="Quandans AI - Ask Questions",
776
+ description="Ask a question based on the document loaded in this system."
777
+ )
778
+
779
  demo.launch()