# pdfassistant / app.py
# Hugging Face Space by priyanshu23456 -- commit b3305c3 ("Update app.py", verified), 4.81 kB.
import os

# Hugging Face Spaces only allows writes under /tmp, so point the model /
# tokenizer caches there. transformers, huggingface_hub and
# sentence_transformers resolve their default cache directories from these
# variables when they are first imported, so the assignments must run BEFORE
# those imports below; assigning them after the imports does not change the
# already-resolved cache paths.
os.environ["TRANSFORMERS_CACHE"] = "/tmp"
os.environ["HF_HOME"] = "/tmp"
os.environ["XDG_CACHE_HOME"] = "/tmp"

# Third-party imports deliberately follow the cache configuration above
# (original relative order preserved).
from flask import Flask, request, jsonify  # noqa: E402
from werkzeug.utils import secure_filename  # noqa: E402
from flask_cors import CORS  # noqa: E402
import torch  # noqa: E402
import fitz  # PyMuPDF  # noqa: E402
import pytesseract  # noqa: E402
from pdf2image import convert_from_path  # noqa: E402
from transformers import pipeline, AutoTokenizer, AutoModelForCausalLM  # noqa: E402
from sentence_transformers import SentenceTransformer  # noqa: E402
import faiss  # noqa: E402
import numpy as np  # noqa: E402

app = Flask(__name__)
CORS(app)  # Enable CORS for all routes.

# Uploaded PDFs are written here before processing.
UPLOAD_FOLDER = "/tmp/uploads"
os.makedirs(UPLOAD_FOLDER, exist_ok=True)

# Run models on the GPU when one is available, otherwise on the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
# OCR fallback for scanned PDFs
def ocr_pdf(pdf_path):
    """Return the text of the PDF at *pdf_path*, obtained by OCR.

    Each page is rasterised with pdf2image and run through Tesseract; the
    per-page strings are concatenated in page order with no separator.
    """
    page_images = convert_from_path(pdf_path)
    return "".join(pytesseract.image_to_string(image) for image in page_images)
# Text extraction with OCR fallback
def extract_text(pdf_path):
    """Extract the text of the PDF at *pdf_path*.

    The embedded text layer is read page by page with PyMuPDF. If the result,
    once stripped, is shorter than 50 characters (e.g. an image-only scanned
    PDF), the document is OCR-ed with ``ocr_pdf`` and that text is used instead.

    Returns:
        The extracted text (possibly empty).
    """
    # Context manager so the document's file handle is released even if
    # text extraction raises.
    with fitz.open(pdf_path) as doc:
        text = "".join(page.get_text() for page in doc)
    if len(text.strip()) < 50:
        text = ocr_pdf(pdf_path)
    return text
# Sentence-based chunking
def split_into_chunks(text, max_tokens=300, overlap=50):
    """Split *text* into overlapping chunks of whole sentences.

    Sentences are obtained by splitting on ``"."``. Each one is re-terminated
    with a period and joined to its neighbour with a single space. A chunk
    grows until adding the next sentence would reach ``max_tokens``
    characters. Despite its name, the budget is measured in characters, not
    model tokens.

    When a chunk is closed and contains more than ``overlap`` words, its last
    ``overlap`` words are carried into the next chunk for context. A sentence
    longer than the budget is kept whole, so chunks may exceed the budget.

    Args:
        text: Source text.
        max_tokens: Character budget per chunk.
        overlap: Number of trailing words repeated at the start of the next
            chunk.

    Returns:
        A list of non-empty chunk strings; empty when *text* has no content.
    """
    chunks = []
    current = ""
    for raw in text.split("."):
        fragment = raw.strip()
        if not fragment:
            # Leading, trailing or consecutive periods yield empty fragments;
            # skip them instead of emitting stray "." sentences.
            continue
        sentence = fragment + "."
        # Separate sentences with a space so words are not fused across them.
        candidate = f"{current} {sentence}" if current else sentence
        if len(candidate) < max_tokens:
            current = candidate
            continue
        if current:
            chunks.append(current)
            words = current.split()
            if len(words) > overlap:
                current = " ".join(words[-overlap:]) + " " + sentence
            else:
                current = sentence
        else:
            # Nothing buffered yet: start from this oversized sentence rather
            # than emitting an empty chunk.
            current = sentence
    if current:
        chunks.append(current)
    return chunks
# FAISS index construction
def setup_faiss(chunks):
    """Embed *chunks* and build a flat (exact) L2 FAISS index over them.

    Args:
        chunks: Non-empty list of text chunks; vector ``i`` in the index
            corresponds to ``chunks[i]``.

    Returns:
        ``(index, embeddings, chunks)``: the populated ``faiss.IndexFlatL2``,
        the float32 embedding matrix (one row per chunk), and the chunks
        unchanged.

    Raises:
        ValueError: If *chunks* is empty (there is nothing to index).
    """
    if not chunks:
        # Fail with a clear message rather than an obscure error from
        # encoding / indexing an empty batch.
        raise ValueError("No text chunks to index")
    # NOTE(review): the embedding model is loaded on every call; consider
    # loading it once at module level if request latency matters.
    embedder = SentenceTransformer("all-MiniLM-L6-v2")
    # FAISS operates on contiguous float32 matrices; make that explicit.
    embeddings = np.ascontiguousarray(embedder.encode(chunks), dtype="float32")
    index = faiss.IndexFlatL2(embeddings.shape[1])
    index.add(embeddings)
    return index, embeddings, chunks
# Extractive QA (first attempt)
def answer_with_qa_pipeline(chunks, question):
    """Try to answer *question* extractively from the start of the document.

    Runs a DistilBERT SQuAD question-answering pipeline over the first five
    chunks joined with spaces.

    Returns:
        The extracted answer span, or ``""`` if running the pipeline or
        reading its answer fails.
    """
    import logging

    # NOTE(review): the pipeline (model + tokenizer) is rebuilt on every call;
    # consider loading it once at module level if latency matters.
    qa_pipeline = pipeline(
        "question-answering",
        model="distilbert-base-cased-distilled-squad",
        tokenizer="distilbert-base-cased",
        device=0 if device == "cuda" else -1,
    )
    context = " ".join(chunks[:5])
    try:
        result = qa_pipeline(question=question, context=context)
        return result["answer"]
    except Exception:
        # Best-effort step: log and return "" so the caller can fall back.
        # ``except Exception`` (not a bare ``except``) lets
        # KeyboardInterrupt / SystemExit propagate.
        logging.getLogger(__name__).warning("QA pipeline failed", exc_info=True)
        return ""
# Generative fallback
def answer_with_generation(index, embeddings, chunks, question):
    """Generate a free-form answer with distilgpt2 from the chunks nearest *question*.

    The question is embedded with the all-MiniLM-L6-v2 sentence-transformer
    and used to search *index*. Up to three nearest chunks become the prompt
    context.

    Args:
        index: FAISS index built over *chunks* (vector ``i`` <-> ``chunks[i]``).
        embeddings: Unused; kept for interface compatibility.
        chunks: The text chunks that were indexed.
        question: The user's question.

    Returns:
        The generated answer text, stripped of surrounding whitespace.
    """
    # NOTE(review): all models here are loaded on every call; consider
    # loading them once at module level if latency matters.
    tokenizer = AutoTokenizer.from_pretrained("distilgpt2")
    model = AutoModelForCausalLM.from_pretrained("distilgpt2").to(device)
    if tokenizer.pad_token is None:
        # distilgpt2 ships without a pad token; reuse EOS so generate() can pad.
        tokenizer.pad_token = tokenizer.eos_token
        model.config.pad_token_id = model.config.eos_token_id
    # Cut over-long prompts from the left so the question and the
    # "Detailed answer:" cue at the end of the prompt always survive.
    tokenizer.truncation_side = "left"

    relevant_chunks = []
    # Never ask FAISS for more neighbours than it holds: missing results are
    # returned as -1, which would silently select chunks[-1].
    k = min(3, index.ntotal)
    if k > 0:
        embedder = SentenceTransformer("all-MiniLM-L6-v2")
        q_embedding = np.asarray(embedder.encode([question]), dtype="float32")
        _, top_k_indices = index.search(q_embedding, k=k)
        relevant_chunks = [chunks[i] for i in top_k_indices[0] if 0 <= i < len(chunks)]
    context = " ".join(relevant_chunks)

    prompt = f"Answer the following question based on this information:\n\nInformation: {context}\n\nQuestion: {question}\n\nDetailed answer:"
    inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=512).to(device)
    output = model.generate(
        **inputs,
        max_new_tokens=300,
        temperature=0.7,
        top_p=0.9,
        do_sample=True,
        num_beams=3,
        no_repeat_ngram_size=2,
    )
    # generate() returns prompt + continuation; decode only the new tokens.
    prompt_length = inputs["input_ids"].shape[-1]
    answer = tokenizer.decode(output[0][prompt_length:], skip_special_tokens=True)
    # If the model echoes the cue itself, keep only what follows its last use.
    if "Detailed answer:" in answer:
        answer = answer.split("Detailed answer:")[-1]
    return answer.strip()
# ---------------------------------------------------------------------------
# API routes
# ---------------------------------------------------------------------------
@app.route('/')
def home():
    """Liveness endpoint: report that the service is up."""
    status = {"message": "PDF QA API is running!"}
    return jsonify(status)
@app.route('/ask', methods=['POST'])
def ask():
    """Answer a question about an uploaded PDF.

    Expects ``multipart/form-data`` with a ``pdf`` file field and a
    ``question`` form field. The extractive QA pipeline is tried first. If
    its answer is shorter than 20 characters, the chunks are indexed with
    FAISS and an answer is generated instead.

    Returns:
        ``{"answer": ...}`` on success. On failure, ``{"error": ...}`` with
        status 400 if an input is missing, 422 if the PDF yields no usable
        text, or 500 if processing fails.
    """
    import contextlib
    import tempfile

    file = request.files.get("pdf")
    # A whitespace-only question is treated as missing.
    question = request.form.get("question", "").strip()
    if not file or not question:
        return jsonify({"error": "Both PDF file and question are required"}), 400

    filepath = None
    try:
        # Save under a unique name. A fixed name lets concurrent uploads of
        # the same file overwrite each other. secure_filename() can also
        # return "", which would make the target the upload directory itself.
        safe_name = secure_filename(file.filename or "") or "upload.pdf"
        fd, filepath = tempfile.mkstemp(suffix="_" + safe_name, dir=UPLOAD_FOLDER)
        os.close(fd)
        file.save(filepath)

        # Process the PDF and produce an answer.
        text = extract_text(filepath)
        chunks = split_into_chunks(text) if text.strip() else []
        if not chunks:
            return jsonify({"error": "No readable text found in the PDF"}), 422
        answer = answer_with_qa_pipeline(chunks, question)
        if len(answer.strip()) < 20:
            index, embeddings, chunks = setup_faiss(chunks)
            answer = answer_with_generation(index, embeddings, chunks, question)
        return jsonify({"answer": answer})
    except Exception as e:
        return jsonify({"error": str(e)}), 500
    finally:
        # Always delete the upload so the upload folder does not grow unbounded.
        if filepath is not None:
            with contextlib.suppress(OSError):
                os.remove(filepath)
if __name__ == "__main__":
    # Start Flask's built-in server on every interface. Port 7860 is
    # presumably the port the hosting Hugging Face Space routes to.
    listen_host, listen_port = "0.0.0.0", 7860
    app.run(host=listen_host, port=listen_port)