# app.py
import os
import io
import fitz  # PyMuPDF
import pytesseract
from pdf2image import convert_from_bytes
import numpy as np
import faiss
import torch
from flask import Flask, request, jsonify
from flask_cors import CORS
from transformers import pipeline, AutoTokenizer, AutoModelForCausalLM
from sentence_transformers import SentenceTransformer

app = Flask(__name__)
CORS(app)  # Enable CORS for cross-origin requests

# Device setup
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")

# Load models at startup (only once)
try:
    print("Loading models...")

    # Embedding model for semantic search
    embedding_model = SentenceTransformer("all-MiniLM-L6-v2").to(device)

    # QA pipeline for direct question answering
    qa_pipeline = pipeline(
        "question-answering",
        model="distilbert-base-cased-distilled-squad",
        tokenizer="distilbert-base-cased",
        device=0 if device == "cuda" else -1
    )

    # Generation model for more detailed responses
    gen_model_name = "distilgpt2"  # Lightweight model suitable for free tier
    gen_tokenizer = AutoTokenizer.from_pretrained(gen_model_name)
    gen_model = AutoModelForCausalLM.from_pretrained(gen_model_name).to(device)

    # Ensure pad token is set for the tokenizer
    if gen_tokenizer.pad_token is None:
        gen_tokenizer.pad_token = gen_tokenizer.eos_token
        gen_model.config.pad_token_id = gen_model.config.eos_token_id

    print("✅ Models loaded successfully")
except Exception as e:
    print(f"❌ Error loading models: {e}")
    raise
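
# Note: all three checkpoints are deliberately small -- MiniLM for embeddings,
# a distilled SQuAD model for extractive QA, and distilgpt2 for generation --
# so the app can start on CPU within free-tier memory limits.
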
# OCR fallback for scanned PDFs
def ocr_pdf(pdf_bytes):
    try:
        images = convert_from_bytes(pdf_bytes)
        text = ""
        for img in images:
            text += pytesseract.image_to_string(img)
        return text
    except Exception as e:
        print(f"❌ OCR error: {e}")
        return ""
# Text extraction from standard PDFs
def extract_text(pdf_bytes):
    try:
        doc = fitz.open(stream=io.BytesIO(pdf_bytes), filetype="pdf")
        text = ""
        for page in doc:
            text += page.get_text()
        if len(text.strip()) < 50:
            print("⚠️ Not enough text, using OCR fallback...")
            text = ocr_pdf(pdf_bytes)
        print("✅ Text extraction complete")
        return text
    except Exception as e:
        print(f"❌ Text extraction error: {e}")
        return ""

# Split into chunks with overlap for better context
def split_into_chunks(text, max_tokens=300, overlap=50):
    # NB: despite the name, max_tokens is a character budget and overlap is a
    # word count, not model tokens.
    sentences = text.split('.')
    chunks, current = [], ''
    for sentence in sentences:
        sentence = sentence.strip() + '.'
        if len(current) + len(sentence) < max_tokens:
            current += sentence
        else:
            chunks.append(current.strip())
            # Keep some overlap for context continuity
            words = current.split()
            if len(words) > overlap:
                current = ' '.join(words[-overlap:]) + ' ' + sentence
            else:
                current = sentence
    if current:
        chunks.append(current.strip())
    return chunks
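
# Illustrative behaviour (hypothetical input): with the defaults, text is
# packed into chunks of just under 300 characters, and the last ~50 words of
# each finished chunk are repeated at the head of the next one, e.g.
#
#   split_into_chunks("First point. Second point. Third point. ...")
#   # -> ["First point. Second point. ...", "... Second point. Third point. ..."]
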
# Setup FAISS index for semantic search
def setup_faiss(chunks):
    try:
        # FAISS expects float32 vectors
        embeddings = np.asarray(embedding_model.encode(chunks), dtype="float32")
        dimension = embeddings.shape[1]
        index = faiss.IndexFlatL2(dimension)
        index.add(embeddings)
        return index, embeddings, chunks
    except Exception as e:
        print(f"❌ FAISS setup error: {e}")
        raise
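
# Sanity check (illustrative): IndexFlatL2 is an exact, brute-force L2 index,
# so each stored vector's nearest neighbour is itself:
#
#   index, embs, chunks = setup_faiss(["alpha beta", "gamma delta"])
#   dists, ids = index.search(embs[:1], 1)  # ids[0][0] == 0, dists[0][0] == 0.0
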
# Get answer using QA pipeline
def answer_with_qa_pipeline(chunks, question):
    try:
        # Join the first few chunks as context (no retrieval at this stage)
        context = " ".join(chunks[:5])  # Using first 5 chunks for simplicity
        result = qa_pipeline(question=question, context=context)
        return result['answer']
    except Exception as e:
        print(f"❌ QA pipeline error: {e}")
        return "Could not generate answer with QA pipeline."
# Get answer using generation model
def answer_with_generation(index, embeddings, chunks, question):
    try:
        # Embed the question (float32 for FAISS)
        q_embedding = np.asarray(embedding_model.encode([question]), dtype="float32")

        # Search the FAISS index for the most relevant chunks
        _, top_k_indices = index.search(q_embedding, k=3)
        relevant_chunks = [chunks[i] for i in top_k_indices[0]]
        context = " ".join(relevant_chunks)

        # Create a clear prompt
        prompt = f"Answer the following question based on this information:\n\nInformation: {context}\n\nQuestion: {question}\n\nDetailed answer:"

        # Generate response
        inputs = gen_tokenizer(prompt, return_tensors="pt", truncation=True, max_length=512).to(device)
        output = gen_model.generate(
            **inputs,
            max_new_tokens=300,
            temperature=0.7,
            top_p=0.9,
            do_sample=True,
            num_beams=3,
            no_repeat_ngram_size=2
        )
        answer = gen_tokenizer.decode(output[0], skip_special_tokens=True)

        # Return only the text after the prompt's final marker
        if "Detailed answer:" in answer:
            return answer.split("Detailed answer:")[-1].strip()
        return answer
    except Exception as e:
        print(f"❌ Generation error: {e}")
        return "Could not generate answer."
# Process PDF and answer question
def process_pdf_and_answer(pdf_bytes, question):
    try:
        # Extract text from PDF
        text = extract_text(pdf_bytes)
        if not text:
            return "Could not extract text from the PDF."

        # Split into chunks
        chunks = split_into_chunks(text)
        if not chunks:
            return "Could not process the PDF content."

        # Try the extractive QA pipeline first
        print("Attempting to answer with QA pipeline...")
        qa_answer = answer_with_qa_pipeline(chunks, question)

        # If the QA answer is too short or empty, fall back to generation
        if len(qa_answer) < 20:
            print("QA answer too short, trying generation approach...")
            index, embeddings, chunks = setup_faiss(chunks)
            return answer_with_generation(index, embeddings, chunks, question)
        return qa_answer
    except Exception as e:
        print(f"❌ Processing error: {e}")
        return f"An error occurred: {str(e)}"
# API Endpoints
@app.route("/health", methods=["GET"])
def health_check():
    """Health check endpoint"""
    return jsonify({"status": "healthy"})

@app.route("/api/ask", methods=["POST"])
def ask_question():
    """Endpoint for asking questions about a PDF"""
    try:
        if 'file' not in request.files:
            return jsonify({"error": "No file provided"}), 400
        file = request.files['file']
        if not file or file.filename == '':
            return jsonify({"error": "Invalid file"}), 400
        if 'question' not in request.form:
            return jsonify({"error": "No question provided"}), 400
        question = request.form['question']
        pdf_bytes = file.read()
        answer = process_pdf_and_answer(pdf_bytes, question)
        return jsonify({
            "answer": answer,
            "success": True
        })
    except Exception as e:
        print(f"❌ API error: {e}")
        return jsonify({"error": str(e), "success": False}), 500
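
# Example request (hypothetical file name), with the server running locally:
#
#   curl -X POST http://localhost:7860/api/ask \
#        -F "file=@document.pdf" \
#        -F "question=What is this document about?"
#
# Response shape: {"answer": "...", "success": true}
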
if __name__ == "__main__":
    # Get port from environment variable or use 7860 (Hugging Face Spaces default)
    port = int(os.environ.get("PORT", 7860))
    app.run(host="0.0.0.0", port=port)
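
# Rough dependency sketch (unpinned; adjust to your environment):
#   pip install flask flask-cors pymupdf pytesseract pdf2image numpy faiss-cpu \
#       torch transformers sentence-transformers
# plus the poppler and tesseract-ocr system packages for pdf2image and
# pytesseract respectively.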