Update app.py
app.py CHANGED
@@ -1,3 +1,5 @@
+'''
+
 import os
 import re
 import json
@@ -246,3 +248,329 @@ demo = gr.Interface(
 )
 
 demo.launch()
+'''
+
+
+import os
+import re
+import json
+import torch
+import numpy as np
+import logging
+from typing import Dict, List, Tuple, Optional
+from tqdm import tqdm
+from pydantic import BaseModel
+from transformers import (
+    AutoTokenizer,
+    AutoModelForSeq2SeqLM,
+    AutoModelForQuestionAnswering,
+    pipeline,
+    LogitsProcessor,
+    LogitsProcessorList,
+    PreTrainedModel,
+    PreTrainedTokenizer
+)
+from sentence_transformers import SentenceTransformer, CrossEncoder
+from sklearn.feature_extraction.text import TfidfVectorizer
+from rank_bm25 import BM25Okapi
+import PyPDF2
+from sklearn.cluster import KMeans
+import spacy
+import subprocess
+import gradio as gr
+
+logging.basicConfig(
+    level=logging.INFO,
+    format="%(asctime)s [%(levelname)s] %(message)s"
+)
+
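+# Rescales logits before softmax: dividing by a factor below 1.0 sharpens the
+# distribution, so the span probabilities used for confidence are less diffuse.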
+class ConfidenceCalibrator(LogitsProcessor):
+    def __init__(self, calibration_factor: float = 0.9):
+        self.calibration_factor = calibration_factor
+
+    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
+        return scores / self.calibration_factor
+
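+# Pydantic schema for a structured answer record (defined for validation but
+# not yet referenced by the pipeline below).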
+class DocumentResult(BaseModel):
+    content: str
+    confidence: float
+    source_page: int
+    supporting_evidence: List[str]
+
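+# Picks the highest-scoring checkpoint for a task, loading it lazily and
+# caching it so repeated calls reuse the same model and tokenizer.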
+class OptimalModelSelector:
+    def __init__(self):
+        self.qa_models = {
+            "deberta-v3": ("deepset/deberta-v3-large-squad2", 0.87)
+        }
+        self.summarization_models = {
+            "bart": ("facebook/bart-large-cnn", 0.85)
+        }
+        self.current_models = {}
+
+    def get_best_model(self, task_type: str) -> Tuple[PreTrainedModel, PreTrainedTokenizer, float]:
+        model_map = self.qa_models if "qa" in task_type else self.summarization_models
+        best_model_name, best_score = max(model_map.items(), key=lambda x: x[1][1])
+        if best_model_name not in self.current_models:
+            tokenizer = AutoTokenizer.from_pretrained(model_map[best_model_name][0])
+            model = (AutoModelForQuestionAnswering if "qa" in task_type
+                     else AutoModelForSeq2SeqLM).from_pretrained(model_map[best_model_name][0])
+            model = model.eval().to('cuda' if torch.cuda.is_available() else 'cpu')
+            if torch.cuda.is_available():
+                model = model.half()  # use fp16 only on GPU; many ops lack half-precision CPU kernels
+            self.current_models[best_model_name] = (model, tokenizer)
+        return *self.current_models[best_model_name], best_score
+
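+# Hybrid retriever: blends lexical BM25 scores (weight 0.4) with cross-encoder
+# relevance scores (weight 0.6) to rank pages for a query.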
+class PDFAugmentedRetriever:
+    def __init__(self, document_texts: List[str]):
+        self.documents = [(i, text) for i, text in enumerate(document_texts)]
+        self.bm25 = BM25Okapi([text.split() for _, text in self.documents])
+        self.encoder = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-6-v2')
+        self.tfidf = TfidfVectorizer(stop_words='english').fit([text for _, text in self.documents])
+
+    def retrieve(self, query: str, top_k: int = 5) -> List[Tuple[int, str, float]]:
+        bm25_scores = self.bm25.get_scores(query.split())
+        semantic_scores = self.encoder.predict([(query, doc) for _, doc in self.documents])
+        combined_scores = 0.4 * bm25_scores + 0.6 * np.array(semantic_scores)
+        top_indices = np.argsort(combined_scores)[-top_k:][::-1]
+        return [(self.documents[i][0], self.documents[i][1], float(combined_scores[i]))
+                for i in top_indices]
+
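+# Pulls candidate concepts (noun chunks and named entities) out of an answer
+# with spaCy, then asks a text2text model to explain each one in context.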
+class DetailedExplainer:
+    def __init__(self,
+                 explanation_model: str = "google/flan-t5-large",
+                 device: int = 0):
+        try:
+            self.nlp = spacy.load("en_core_web_sm")
+        except OSError:
+            subprocess.run(["python", "-m", "spacy", "download", "en_core_web_sm"], check=True)
+            self.nlp = spacy.load("en_core_web_sm")
+        self.explainer = pipeline(
+            "text2text-generation",
+            model=explanation_model,
+            tokenizer=explanation_model,
+            device=device,
+            max_new_tokens=800  # setting max_length as well would conflict; max_new_tokens takes precedence
+        )
+
+    def extract_concepts(self, text: str) -> list:
+        doc = self.nlp(text)
+        concepts = set()
+        for chunk in doc.noun_chunks:
+            if len(chunk) > 1 and not chunk.root.is_stop:
+                concepts.add(chunk.text.strip())
+        for ent in doc.ents:
+            if ent.label_ in ["PERSON", "ORG", "GPE", "NORP", "EVENT", "WORK_OF_ART"]:
+                concepts.add(ent.text.strip())
+        return list(concepts)
+
+    def explain_concept(self, concept: str, context: str, min_accuracy: float = 0.50) -> str:
+        prompt = (
+            f"The following passage from a PDF is given:\n{context}\n\n\n"
+            f"Now explain the concept '{concept}' mentioned above with at least {int(min_accuracy * 100)}% accuracy."
+        )
+        result = self.explainer(
+            prompt,
+            do_sample=False
+        )
+        return result[0]["generated_text"].strip()
+
+    def explain_text(self, text: str, context: str) -> dict:
+        concepts = self.extract_concepts(text)
+        explanations = {}
+        for concept in concepts:
+            explanations[concept] = self.explain_concept(concept, context)
+        return {"concepts": concepts, "explanations": explanations}
+
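+# Orchestrates the full pipeline: PDF text extraction, hybrid retrieval,
+# extractive QA with calibrated confidence, and per-concept explanations.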
+class AdvancedPDFAnalyzer:
+    def __init__(self):
+        self.logger = logging.getLogger("PDFAnalyzer")
+        self.model_selector = OptimalModelSelector()
+        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+        self.qa_model, self.qa_tokenizer, _ = self.model_selector.get_best_model("qa")
+        self.qa_model = self.qa_model.to(self.device)
+        self.summarizer = pipeline(
+            "summarization",
+            model="facebook/bart-large-cnn",
+            device=0 if torch.cuda.is_available() else -1,
+            framework="pt"
+        )
+        self.logits_processor = LogitsProcessorList([
+            ConfidenceCalibrator(calibration_factor=0.85)
+        ])
+        self.detailed_explainer = DetailedExplainer(device=0 if torch.cuda.is_available() else -1)
+
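+    # Reads the PDF page by page, skips pages with no extractable text, and
+    # attaches source/page/length metadata to each cleaned chunk.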
+    def extract_text_with_metadata(self, file_path: str) -> List[Dict]:
+        documents = []
+        with open(file_path, 'rb') as f:
+            reader = PyPDF2.PdfReader(f)
+            for i, page in enumerate(reader.pages):
+                text = page.extract_text()
+                if not text or not text.strip():
+                    continue
+                page_number = i + 1
+                metadata = {
+                    'source': os.path.basename(file_path),
+                    'page': page_number,
+                    'char_count': len(text),
+                    'word_count': len(text.split()),
+                }
+                documents.append({
+                    'content': self._clean_text(text),
+                    'metadata': metadata
+                })
+        if not documents:
+            raise ValueError("No extractable content found in PDF")
+        return documents
+
+    def _clean_text(self, text: str) -> str:
+        text = re.sub(r'[\x00-\x1F\x7F-\x9F]', ' ', text)
+        text = re.sub(r'\s+', ' ', text)
+        text = re.sub(r'(\w)-\s+(\w)', r'\1\2', text)
+        return text.strip()
+
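+    # Retrieves the top-3 pages for the question, extracts an answer span from
+    # each, folds the span probabilities and retrieval score into a single
+    # confidence value, and returns the best-scoring candidate.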
+    def answer_question(self, question: str, documents: List[Dict]) -> Dict:
+        retriever = PDFAugmentedRetriever([doc['content'] for doc in documents])
+        relevant_contexts = retriever.retrieve(question, top_k=3)
+        answers = []
+
+        for page_idx, context, similarity_score in relevant_contexts:
+            inputs = self.qa_tokenizer(
+                question,
+                context,
+                add_special_tokens=True,
+                return_tensors="pt",
+                max_length=512,
+                truncation=True
+            )
+            inputs = {k: v.to(self.device) for k, v in inputs.items()}
+
+            with torch.no_grad():
+                outputs = self.qa_model(**inputs)
+                start_logits = outputs.start_logits
+                end_logits = outputs.end_logits
+
+            logits_processor = LogitsProcessorList([ConfidenceCalibrator()])
+            start_logits = logits_processor(inputs['input_ids'], start_logits)
+            end_logits = logits_processor(inputs['input_ids'], end_logits)
+
+            start_prob = torch.nn.functional.softmax(start_logits, dim=-1)
+            end_prob = torch.nn.functional.softmax(end_logits, dim=-1)
+
+            max_start_score, max_start_idx = torch.max(start_prob, dim=-1)
+            max_start_idx_int = max_start_idx.item()
+            max_end_score, max_end_idx = torch.max(end_prob[0, max_start_idx_int:], dim=-1)
+            max_end_idx_int = max_end_idx.item() + max_start_idx_int
+
+            confidence = float((max_start_score * max_end_score) * 0.9 * similarity_score)
+            answer_tokens = inputs["input_ids"][0][max_start_idx_int:max_end_idx_int + 1]
+            answer = self.qa_tokenizer.decode(answer_tokens, skip_special_tokens=True)
+
+            # Only generate explanations if we have a valid answer
+            explanations_result = {"concepts": [], "explanations": {}}
+            if answer and answer.strip():
+                try:
+                    explanations_result = self.detailed_explainer.explain_text(answer, context)
+                except Exception as e:
+                    self.logger.warning(f"Failed to generate explanations: {e}")
+
+            answers.append({
+                "answer": answer,
+                "confidence": confidence,
+                "context": context,
+                "page_number": documents[page_idx]['metadata']['page'],
+                "explanations": explanations_result
+            })
+
+        if not answers:
+            return {
+                "answer": "No confident answer found",
+                "confidence": 0.0,
+                "explanations": {"concepts": [], "explanations": {}},
+                "page_number": 0,
+                "context": ""
+            }
+
+        # Get the best answer based on confidence
+        best_answer = max(answers, key=lambda x: x['confidence'])
+
+        # FIXED: Always return the best answer dictionary, just modify the answer text if confidence is low
+        if best_answer['confidence'] < 0.3:  # Lowered threshold to be more permissive
+            best_answer['answer'] = f"[Low Confidence] {best_answer['answer']}"
+
+        return best_answer
+
+# Initialize analyzer (make sure to update the PDF path)
+analyzer = AdvancedPDFAnalyzer()
+
+# Global variable to store documents
+documents = []
+
+def load_pdf(file_path: str):
+    """Load PDF and extract documents"""
+    global documents
+    try:
+        documents = analyzer.extract_text_with_metadata(file_path)
+        return f"Successfully loaded PDF with {len(documents)} pages."
+    except Exception as e:
+        return f"Error loading PDF: {str(e)}"
+
+def ask_question_gradio(question: str):
+    if not question.strip():
+        return "Please enter a valid question."
+
+    if not documents:
+        return "❌ No PDF loaded. Please load a PDF first."
+
+    try:
+        result = analyzer.answer_question(question, documents)
+
+        # Ensure we have the expected structure
+        answer = result.get('answer', 'No answer found')
+        confidence = result.get('confidence', 0.0)
+        page_number = result.get('page_number', 0)
+        explanations = result.get("explanations", {}).get("explanations", {})
+
+        # Format explanations
+        explanation_text = ""
+        if explanations:
+            explanation_text = "\n\n".join(
+                f"🔹 **{concept}**: {desc}"
+                for concept, desc in explanations.items()
+                if desc and desc.strip()
+            )
+
+        # Build response
+        response_parts = [
+            f"📝 **Answer**: {answer}",
+            f"📊 **Confidence**: {confidence:.2f}",
+            f"📄 **Page**: {page_number}"
+        ]
+
+        if explanation_text:
+            response_parts.append(f"📖 **Explanations**:\n{explanation_text}")
+
+        return "\n\n".join(response_parts)
+
+    except Exception as e:
+        return f"❌ Error: {str(e)}"
+
+# Load your PDF here - update the path to your actual PDF file
+pdf_path = "example.pdf"
+if os.path.exists(pdf_path):
+    load_result = load_pdf(pdf_path)
+    print(load_result)
+else:
+    print(f"PDF file '{pdf_path}' not found. Please update the path.")
+
+demo = gr.Interface(
+    fn=ask_question_gradio,
+    inputs=gr.Textbox(label="Ask a question about the PDF", placeholder="Type your question here..."),
+    outputs=gr.Markdown(label="Answer"),
+    title="Quandans AI - Ask Questions",
+    description="Ask a question based on the document loaded in this system.",
+    examples=[
+        "What is the main topic of this document?",
+        "Summarize the key points from page 1",
+        "What are the conclusions mentioned?"
+    ]
+)
+
+demo.launch()