# app.py
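"""Flask API for question answering over uploaded PDFs.

Text is extracted with PyMuPDF (falling back to OCR via pdf2image/pytesseract for
scanned documents), split into overlapping chunks, and answered with an extractive
QA pipeline; if that answer is too short, a FAISS-backed retrieval step feeds a
small generative model instead.
"""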
import os
import io
import fitz  # PyMuPDF
import pytesseract
from pdf2image import convert_from_bytes
import numpy as np
import faiss
import torch
from flask import Flask, request, jsonify
from flask_cors import CORS
from transformers import pipeline, AutoTokenizer, AutoModelForCausalLM
from sentence_transformers import SentenceTransformer

app = Flask(__name__)
CORS(app)  # Enable CORS for cross-origin requests

# Device setup
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")

# Load models at startup (only once)
try:
    print("Loading models...")
    # Embedding model for semantic search
    embedding_model = SentenceTransformer("all-MiniLM-L6-v2").to(device)
    
    # QA pipeline for direct question answering
    qa_pipeline = pipeline(
        "question-answering",
        model="distilbert-base-cased-distilled-squad",
        tokenizer="distilbert-base-cased",
        device=0 if device == "cuda" else -1
    )
    
    # Generation model for more detailed responses
    gen_model_name = "distilgpt2"  # Lightweight model suitable for free tier
    gen_tokenizer = AutoTokenizer.from_pretrained(gen_model_name)
    gen_model = AutoModelForCausalLM.from_pretrained(gen_model_name).to(device)
    
    # Ensure pad token is set for the tokenizer
    if gen_tokenizer.pad_token is None:
        gen_tokenizer.pad_token = gen_tokenizer.eos_token
        gen_model.config.pad_token_id = gen_model.config.eos_token_id
    
    print("βœ… Models loaded successfully")
except Exception as e:
    print(f"❌ Error loading models: {e}")
    raise

# OCR fallback for scanned PDFs
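# Note: pdf2image needs the poppler utilities and pytesseract needs the Tesseract
# binary installed on the host system for this fallback to work.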
def ocr_pdf(pdf_bytes):
    try:
        images = convert_from_bytes(pdf_bytes)
        text = ""
        for img in images:
            text += pytesseract.image_to_string(img)
        return text
    except Exception as e:
        print(f"❌ OCR error: {e}")
        return ""

# Text extraction from standard PDFs
def extract_text(pdf_bytes):
    try:
        doc = fitz.open(stream=io.BytesIO(pdf_bytes), filetype="pdf")
        text = ""
        for page in doc:
            text += page.get_text()
        if len(text.strip()) < 50:
            print("⚠️ Not enough text, using OCR fallback...")
            text = ocr_pdf(pdf_bytes)
        print("βœ… Text extraction complete")
        return text
    except Exception as e:
        print(f"❌ Text extraction error: {e}")
        return ""

# Split text into chunks with word overlap for context continuity.
# Note: the size limit is measured in characters (a rough proxy for tokens).
def split_into_chunks(text, max_chars=300, overlap_words=50):
    sentences = text.split('.')
    chunks, current = [], ''
    for sentence in sentences:
        sentence = sentence.strip()
        if not sentence:
            continue
        sentence += '.'
        if len(current) + len(sentence) < max_chars:
            current += sentence
        else:
            if current:
                chunks.append(current.strip())
            # Carry the last few words over so adjacent chunks share some context
            words = current.split()
            if len(words) > overlap_words:
                current = ' '.join(words[-overlap_words:]) + ' ' + sentence
            else:
                current = sentence
    if current:
        chunks.append(current.strip())
    return chunks

# Setup FAISS index for semantic search
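# IndexFlatL2 performs exact (brute-force) L2 search, which is adequate for the
# small number of chunks produced from a single PDF.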
def setup_faiss(chunks):
    try:
        embeddings = embedding_model.encode(chunks, convert_to_numpy=True)
        # FAISS expects a contiguous float32 array
        embeddings = np.asarray(embeddings, dtype="float32")
        dimension = embeddings.shape[1]
        
        index = faiss.IndexFlatL2(dimension)
        index.add(embeddings)
        return index, embeddings, chunks
    except Exception as e:
        print(f"❌ FAISS setup error: {e}")
        raise

# Get answer using QA pipeline
def answer_with_qa_pipeline(chunks, question):
    try:
        # Join relevant chunks for context
        context = " ".join(chunks[:5])  # Using first 5 chunks for simplicity
        
        result = qa_pipeline(question=question, context=context)
        return result['answer']
    except Exception as e:
        print(f"❌ QA pipeline error: {e}")
        return ""  # Empty answer lets the caller fall back to the generation approach

# Get answer using generation model
def answer_with_generation(index, embeddings, chunks, question):
    try:
        # Embed the question (FAISS expects a float32 array)
        q_embedding = embedding_model.encode([question], convert_to_numpy=True)
        q_embedding = np.asarray(q_embedding, dtype="float32")

        # Search the FAISS index for the most relevant chunks
        _, top_k_indices = index.search(q_embedding, 3)
        relevant_chunks = [chunks[i] for i in top_k_indices[0]]
        context = " ".join(relevant_chunks)
        
        # Create a clear prompt
        prompt = f"Answer the following question based on this information:\n\nInformation: {context}\n\nQuestion: {question}\n\nDetailed answer:"
        
        # Generate response
        inputs = gen_tokenizer(prompt, return_tensors="pt", truncation=True, max_length=512).to(device)
        output = gen_model.generate(
            **inputs,
            max_new_tokens=300,
            temperature=0.7,
            top_p=0.9,
            do_sample=True,
            num_beams=3,
            no_repeat_ngram_size=2
        )
        answer = gen_tokenizer.decode(output[0], skip_special_tokens=True)
        
        # Extract the answer part
        if "Detailed answer:" in answer:
            return answer.split("Detailed answer:")[-1].strip()
        return answer
    except Exception as e:
        print(f"❌ Generation error: {e}")
        return "Could not generate answer."

# Process PDF and answer question
def process_pdf_and_answer(pdf_bytes, question):
    try:
        # Extract text from PDF
        text = extract_text(pdf_bytes)
        if not text:
            return "Could not extract text from the PDF."
        
        # Split into chunks
        chunks = split_into_chunks(text)
        if not chunks:
            return "Could not process the PDF content."
        
        # Try QA pipeline first
        print("Attempting to answer with QA pipeline...")
        qa_answer = answer_with_qa_pipeline(chunks, question)
        
        # If the QA answer is empty or too short, fall back to the generation approach
        if not qa_answer or len(qa_answer) < 20:
            print("QA answer too short, trying generation approach...")
            index, embeddings, chunks = setup_faiss(chunks)
            gen_answer = answer_with_generation(index, embeddings, chunks, question)
            return gen_answer
        
        return qa_answer
    except Exception as e:
        print(f"❌ Processing error: {e}")
        return f"An error occurred: {str(e)}"

# API Endpoints

@app.route("/health", methods=["GET"])
def health_check():
    """Health check endpoint"""
    return jsonify({"status": "healthy"})

@app.route("/api/ask", methods=["POST"])
def ask_question():
    """Endpoint for asking questions about a PDF"""
    try:
        if 'file' not in request.files:
            return jsonify({"error": "No file provided"}), 400
        
        file = request.files['file']
        if not file or file.filename == '':
            return jsonify({"error": "Invalid file"}), 400
        
        if 'question' not in request.form:
            return jsonify({"error": "No question provided"}), 400
        
        question = request.form['question']
        pdf_bytes = file.read()
        
        answer = process_pdf_and_answer(pdf_bytes, question)
        
        return jsonify({
            "answer": answer,
            "success": True
        })
    except Exception as e:
        print(f"❌ API error: {e}")
        return jsonify({"error": str(e), "success": False}), 500

if __name__ == "__main__":
    # Get port from environment variable or use 7860 (Hugging Face Spaces default)
    port = int(os.environ.get("PORT", 7860))
    app.run(host="0.0.0.0", port=port)
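
# Example client call (a minimal sketch; "sample.pdf" and the localhost URL are
# illustrative assumptions, and the port must match the one the server is using):
#
#   import requests
#
#   with open("sample.pdf", "rb") as f:
#       resp = requests.post(
#           "http://localhost:7860/api/ask",
#           files={"file": f},
#           data={"question": "What is this document about?"},
#       )
#   print(resp.json()["answer"])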