priyanshu23456 committed
Commit d47a566 · verified · 1 Parent(s): 63da8c0

Update app.py

Files changed (1)
  1. app.py +121 -74
app.py CHANGED
@@ -1,90 +1,137 @@
  import os
- from flask import Flask, request, jsonify
- from flask_cors import CORS
- import io
  import fitz  # PyMuPDF
+ import pytesseract
+ from pdf2image import convert_from_path
  import torch
- from transformers import pipeline
+ import faiss
+ import numpy as np
+ from transformers import pipeline, AutoTokenizer, AutoModelForCausalLM
+ from sentence_transformers import SentenceTransformer
+ import gradio as gr

- app = Flask(__name__)
- CORS(app)
-
- # Device setup
+ # ✅ Device setup
  device = "cuda" if torch.cuda.is_available() else "cpu"
- print(f"Using device: {device}")

- # Load a simpler QA model
- try:
-     print("Loading QA model...")
-     qa_model = pipeline(
+ # ✅ OCR fallback
+ def ocr_pdf(pdf_path):
+     images = convert_from_path(pdf_path)
+     text = ""
+     for img in images:
+         text += pytesseract.image_to_string(img)
+     return text
+
+ # ✅ Text extraction
+ def extract_text(pdf_path):
+     doc = fitz.open(pdf_path)
+     text = ""
+     for page in doc:
+         text += page.get_text()
+     if len(text.strip()) < 50:
+         print("⚠️ Not enough text, using OCR fallback...")
+         text = ocr_pdf(pdf_path)
+     print("✅ Text extraction complete")
+     return text
+
+ # ✅ Chunking
+ def split_into_chunks(text, max_tokens=300, overlap=50):
+     sentences = text.split('.')
+     chunks, current = [], ''
+     for sentence in sentences:
+         sentence = sentence.strip() + '.'
+         if len(current) + len(sentence) < max_tokens:
+             current += sentence
+         else:
+             chunks.append(current.strip())
+             words = current.split()
+             if len(words) > overlap:
+                 current = ' '.join(words[-overlap:]) + ' ' + sentence
+             else:
+                 current = sentence
+     if current:
+         chunks.append(current.strip())
+     return chunks
+
+ # ✅ FAISS setup
+ def setup_faiss(chunks):
+     embedder = SentenceTransformer("all-MiniLM-L6-v2")
+     embeddings = embedder.encode(chunks)
+     dimension = embeddings.shape[1]
+     index = faiss.IndexFlatL2(dimension)
+     index.add(embeddings)
+     return index, embeddings, chunks
+
+ # ✅ QA method
+ def answer_with_qa_pipeline(chunks, question):
+     qa_pipeline = pipeline(
          "question-answering",
          model="distilbert-base-cased-distilled-squad",
+         tokenizer="distilbert-base-cased",
          device=0 if device == "cuda" else -1
      )
-     print("✅ Model loaded successfully")
- except Exception as e:
-     print(f"❌ Error loading model: {e}")
-     raise
-
- # Text extraction from PDFs
- def extract_text(pdf_bytes):
+     context = " ".join(chunks[:5])
      try:
-         doc = fitz.open(stream=io.BytesIO(pdf_bytes), filetype="pdf")
-         text = ""
-         for page in doc:
-             text += page.get_text()
-         print("✅ Text extraction complete")
-         return text
-     except Exception as e:
-         print(f"❌ Text extraction error: {e}")
-         return ""
-
- # Process PDF and answer question
- def process_pdf_and_answer(pdf_bytes, question):
-     try:
-         # Extract text from PDF
-         text = extract_text(pdf_bytes)
-         if not text:
-             return "Could not extract text from the PDF."
-
-         # Use QA model directly (limiting context size for memory constraints)
-         result = qa_model(question=question, context=text[:5000])
+         result = qa_pipeline(question=question, context=context)
          return result['answer']
-     except Exception as e:
-         print(f"❌ Processing error: {e}")
-         return f"An error occurred: {str(e)}"
+     except:
+         return "Could not answer with QA pipeline."
+
+ # ✅ Generation method
+ def answer_with_generation(index, embeddings, chunks, question):
+     model_name = "distilgpt2"
+     tokenizer = AutoTokenizer.from_pretrained(model_name)
+     model = AutoModelForCausalLM.from_pretrained(model_name).to(device)
+
+     if tokenizer.pad_token is None:
+         tokenizer.pad_token = tokenizer.eos_token
+         model.config.pad_token_id = model.config.eos_token_id

- # API Endpoints
- @app.route("/health", methods=["GET"])
- def health_check():
-     return jsonify({"status": "healthy"})
+     embedder = SentenceTransformer("all-MiniLM-L6-v2")
+     q_embedding = embedder.encode([question])
+     _, top_k_indices = index.search(q_embedding, k=3)
+     relevant_chunks = [chunks[i] for i in top_k_indices[0]]
+     context = " ".join(relevant_chunks)
+
+     prompt = f"Answer the following question based on this information:\n\nInformation: {context}\n\nQuestion: {question}\n\nDetailed answer:"
+     inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=512).to(device)

- @app.route("/api/ask", methods=["POST"])
- def ask_question():
      try:
-         if 'file' not in request.files:
-             return jsonify({"error": "No file provided"}), 400
-
-         file = request.files['file']
-         if not file or file.filename == '':
-             return jsonify({"error": "Invalid file"}), 400
-
-         if 'question' not in request.form:
-             return jsonify({"error": "No question provided"}), 400
-
-         question = request.form['question']
-         pdf_bytes = file.read()
-
-         answer = process_pdf_and_answer(pdf_bytes, question)
-
-         return jsonify({
-             "answer": answer,
-             "success": True
-         })
-     except Exception as e:
-         print(f"❌ API error: {e}")
-         return jsonify({"error": str(e), "success": False}), 500
+         output = model.generate(
+             **inputs,
+             max_new_tokens=300,
+             temperature=0.7,
+             top_p=0.9,
+             do_sample=True,
+             num_beams=3,
+             no_repeat_ngram_size=2
+         )
+         answer = tokenizer.decode(output[0], skip_special_tokens=True)
+         if "Detailed answer:" in answer:
+             return answer.split("Detailed answer:")[-1].strip()
+         return answer
+     except:
+         return "Could not generate answer."
+
+ # ✅ Main logic
+ def process_pdf(file, question):
+     pdf_path = file.name
+     text = extract_text(pdf_path)
+     chunks = split_into_chunks(text)
+     qa_answer = answer_with_qa_pipeline(chunks, question)
+     if len(qa_answer) < 20:
+         index, embeddings, chunks = setup_faiss(chunks)
+         return answer_with_generation(index, embeddings, chunks, question)
+     return qa_answer
+
+ # ✅ Gradio UI
+ iface = gr.Interface(
+     fn=process_pdf,
+     inputs=[
+         gr.File(label="Upload PDF"),
+         gr.Textbox(label="Ask a question", placeholder="What is this PDF about?")
+     ],
+     outputs="text",
+     title="📄 PDF Chat Assistant",
+     description="Upload a PDF and ask anything about its content, even if it has scanned images!"
+ )

- if __name__ == "__main__":
-     port = int(os.environ.get("PORT", 7860))
-     app.run(host="0.0.0.0", port=port)
+ iface.launch()
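
For readers unfamiliar with the retrieval step this commit introduces, here is a minimal, self-contained sketch of what setup_faiss plus the index.search call in answer_with_generation amount to: embed chunks with all-MiniLM-L6-v2, add them to a flat L2 FAISS index, and look up the chunks nearest to the question. The chunk texts and question below are made-up examples, not taken from the app.

import faiss
import numpy as np
from sentence_transformers import SentenceTransformer

# Made-up chunks standing in for split_into_chunks() output
chunks = [
    "Payment is due within 30 days of the invoice date.",
    "The warranty covers manufacturing defects for two years.",
    "Returns must be requested within 14 days of delivery.",
]

embedder = SentenceTransformer("all-MiniLM-L6-v2")
embeddings = np.asarray(embedder.encode(chunks), dtype="float32")  # FAISS works on float32 vectors

index = faiss.IndexFlatL2(embeddings.shape[1])  # exact L2 search, same index type as setup_faiss()
index.add(embeddings)

question = "How long is the warranty?"
q_embedding = np.asarray(embedder.encode([question]), dtype="float32")
distances, top_k = index.search(q_embedding, 2)  # indices of the two nearest chunks
print([chunks[i] for i in top_k[0]])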
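A rough way to exercise the new flow locally, bypassing the Gradio UI, could look like the sketch below. It assumes app.py's functions can be imported without side effects (the module-level iface.launch() would otherwise start the UI on import, so it may need to be guarded), that "sample.pdf" is a hypothetical local file, and that the OCR fallback's system dependencies are installed (the Tesseract binary for pytesseract, Poppler for pdf2image).

# Hypothetical local smoke test; "sample.pdf" is not part of the repo.
from app import (
    extract_text,
    split_into_chunks,
    answer_with_qa_pipeline,
    setup_faiss,
    answer_with_generation,
)

question = "What is this document about?"
text = extract_text("sample.pdf")      # falls back to OCR when the PDF has little embedded text
chunks = split_into_chunks(text)

answer = answer_with_qa_pipeline(chunks, question)
if len(answer) < 20:                   # same fallback condition process_pdf() uses
    index, embeddings, chunks = setup_faiss(chunks)
    answer = answer_with_generation(index, embeddings, chunks, question)
print(answer)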