Spaces:
Sleeping
IZERE HIRWA Roger
committed on
Commit · b7d66ac · 1 Parent(s): 1cc833a
olm
Browse files
- app.py +139 -5
- requirements.txt +4 -1
- static/index.html +8 -3
- static/script.js +7 -3
app.py
CHANGED
@@ -17,6 +17,11 @@ from datetime import datetime, timedelta
 import jwt
 import sqlite3
 import tempfile
+import base64
+from io import BytesIO
+from transformers import AutoModelForImageTextToText, AutoProcessor, AutoTokenizer, TextIteratorStreamer
+from threading import Thread
+import time
 
 app = Flask(__name__)
 app.config['SECRET_KEY'] = 'your-secret-key-change-this-in-production'
@@ -109,6 +114,29 @@ except Exception as e:
 clip_model = None
 preprocess = None
 
+# Initialize Nanonets OCR model
+ocr_model = None
+ocr_processor = None
+ocr_tokenizer = None
+
+try:
+    model_path = "nanonets/Nanonets-OCR-s"
+    print("Loading Nanonets OCR model...")
+    ocr_model = AutoModelForImageTextToText.from_pretrained(
+        model_path,
+        torch_dtype="auto",
+        device_map="auto",
+        trust_remote_code=True
+    )
+    ocr_model.eval()
+
+    ocr_processor = AutoProcessor.from_pretrained(model_path, trust_remote_code=True)
+    ocr_tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
+    print("✅ Nanonets OCR model loaded successfully!")
+except Exception as e:
+    print(f"❌ Failed to load Nanonets OCR model: {e}")
+    print("🔄 Falling back to pytesseract for OCR")
+
 # Helper functions
 def save_index():
     try:
@@ -151,7 +179,95 @@ def image_from_pdf(pdf_bytes):
         print(f"❌ PDF conversion error: {e}")
         return None
 
-def extract_text(image):
+def process_tags(content: str) -> str:
+    """Process special tags from Nanonets OCR output"""
+    content = content.replace("<img>", "&lt;img&gt;")
+    content = content.replace("</img>", "&lt;/img&gt;")
+    content = content.replace("<watermark>", "&lt;watermark&gt;")
+    content = content.replace("</watermark>", "&lt;/watermark&gt;")
+    content = content.replace("<page_number>", "&lt;page_number&gt;")
+    content = content.replace("</page_number>", "&lt;/page_number&gt;")
+    content = content.replace("<signature>", "&lt;signature&gt;")
+    content = content.replace("</signature>", "&lt;/signature&gt;")
+    return content
+
+def encode_image(image: Image) -> str:
+    """Encode image to base64 for Nanonets OCR"""
+    buffered = BytesIO()
+    image.save(buffered, format="JPEG")
+    img_str = base64.b64encode(buffered.getvalue()).decode("utf-8")
+    return img_str
+
+def nanonets_ocr_extract(image):
+    """Extract text using Nanonets OCR model"""
+    try:
+        if ocr_model is None or ocr_processor is None or ocr_tokenizer is None:
+            # Fallback to pytesseract
+            return extract_text_pytesseract(image)
+
+        if image.mode != 'RGB':
+            image = image.convert('RGB')
+
+        # Resize image for optimal processing
+        image = image.resize((2048, 2048))
+
+        # Prepare prompt for OCR extraction
+        user_prompt = """Extract the text from the above document as if you were reading it naturally. Return the tables in html format. Return the equations in LaTeX representation. Watermarks should be wrapped in brackets. Ex: <watermark>OFFICIAL COPY</watermark>. Page numbers should be wrapped in brackets. Ex: <page_number>14</page_number> or <page_number>9/22</page_number>. Prefer using ☐ and ☑ for check boxes."""
+
+        # Format messages for the model
+        formatted_messages = [
+            {"role": "system", "content": "You are a helpful assistant."},
+            {"role": "user", "content": [
+                {"type": "image", "image": image},
+                {"type": "text", "text": user_prompt},
+            ]},
+        ]
+
+        # Apply chat template
+        text = ocr_processor.apply_chat_template(
+            formatted_messages,
+            tokenize=False,
+            add_generation_prompt=True
+        )
+
+        # Process inputs
+        inputs = ocr_processor(
+            text=[text],
+            images=[image],
+            padding=True,
+            return_tensors="pt"
+        )
+
+        # Move inputs to model device
+        inputs = {k: v.to(ocr_model.device) if hasattr(v, 'to') else v for k, v in inputs.items()}
+
+        # Generate text
+        with torch.no_grad():
+            generated_ids = ocr_model.generate(
+                **inputs,
+                max_new_tokens=4096,
+                do_sample=False,
+                pad_token_id=ocr_tokenizer.eos_token_id,
+            )
+
+        # Decode the generated text
+        generated_text = ocr_tokenizer.decode(
+            generated_ids[0][inputs['input_ids'].shape[1]:],
+            skip_special_tokens=True
+        )
+
+        # Process special tags
+        processed_text = process_tags(generated_text)
+
+        return processed_text.strip() if processed_text.strip() else "❌ No text detected"
+
+    except Exception as e:
+        print(f"❌ Nanonets OCR error: {e}")
+        # Fallback to pytesseract
+        return extract_text_pytesseract(image)
+
+def extract_text_pytesseract(image):
+    """Fallback OCR using pytesseract"""
     try:
         if image.mode != 'RGB':
             image = image.convert('RGB')
@@ -161,6 +277,13 @@ def extract_text(image):
     except Exception as e:
         return f"❌ OCR error: {str(e)}"
 
+def extract_text(image):
+    """Main OCR function - tries Nanonets first, falls back to pytesseract"""
+    if ocr_model is not None:
+        return nanonets_ocr_extract(image)
+    else:
+        return extract_text_pytesseract(image)
+
 def get_clip_embedding(image):
     try:
         if clip_model is None:
@@ -297,10 +420,10 @@ def classify_document():
         sim = float(1 - D[0][i])
         matches.append({"category": labels[I[0][i]], "similarity": round(sim, 3)})
 
-    # Save classified document to SQLite
+    # Save classified document to SQLite with enhanced OCR
     if similarity >= confidence_threshold:
         saved_filename = save_uploaded_file(file_content, file.filename)
-        ocr_text = extract_text(image)
+        ocr_text = extract_text(image)  # Now uses Nanonets OCR
 
         document_id = str(uuid.uuid4())
         conn = sqlite3.connect(DATABASE_PATH)
@@ -320,7 +443,8 @@ def classify_document():
             "confidence": "high",
             "matches": matches,
            "document_saved": True,
-            "document_id": document_id
+            "document_id": document_id,
+            "ocr_preview": ocr_text[:200] + "..." if len(ocr_text) > 200 else ocr_text
         })
     else:
         return jsonify({
@@ -452,8 +576,18 @@ def ocr_document():
         if image is None:
             return jsonify({"error": "Failed to process image"}), 400
 
+        # Use enhanced Nanonets OCR
         text = extract_text(image)
-        return jsonify({"text": text, "status": "success"})
+
+        # Determine OCR method used
+        ocr_method = "Nanonets OCR-s" if ocr_model is not None else "Pytesseract"
+
+        return jsonify({
+            "text": text,
+            "status": "success",
+            "ocr_method": ocr_method,
+            "enhanced_features": ocr_model is not None
+        })
     except Exception as e:
         return jsonify({"error": str(e)}), 500
 
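For a sense of the end-to-end change, here is a minimal client sketch against the updated /api/ocr endpoint. The endpoint path, multipart field name, and response keys come from this commit; the base URL, the Bearer-token header implied by authenticatedFetch, and the sample filename are assumptions.

import requests

BASE_URL = "http://localhost:7860"  # hypothetical; wherever the Space is served
TOKEN = "YOUR_JWT_HERE"             # hypothetical; authenticatedFetch suggests a Bearer token

with open("sample.pdf", "rb") as f:  # hypothetical input file
    resp = requests.post(
        f"{BASE_URL}/api/ocr",
        headers={"Authorization": f"Bearer {TOKEN}"},
        files={"file": ("sample.pdf", f, "application/pdf")},
    )
resp.raise_for_status()
data = resp.json()

# New fields added by this commit alongside "text" and "status":
print(data["ocr_method"])          # "Nanonets OCR-s" or "Pytesseract"
print(data["enhanced_features"])   # True when the Nanonets model loaded
print(data["text"][:200])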
requirements.txt
CHANGED
@@ -8,4 +8,7 @@ torchvision
 Pillow
 PyJWT
 git+https://github.com/openai/CLIP.git
-poppler-utils
+poppler-utils
+transformers
+accelerate
+spaces
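Note on the new pins: transformers supplies the AutoModelForImageTextToText, AutoProcessor, and AutoTokenizer classes used in app.py, and accelerate is what honors device_map="auto" in from_pretrained. A quick sanity-check sketch; it only verifies that the imports resolve:

import transformers
import accelerate

# Classes app.py relies on; an ImportError here means the pins above are missing.
from transformers import AutoModelForImageTextToText, AutoProcessor, AutoTokenizer

print(transformers.__version__, accelerate.__version__)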
static/index.html
CHANGED
@@ -255,10 +255,14 @@
             <div id="ocr" class="tab-content">
                 <div class="card">
                     <div class="card-header bg-warning text-dark">
-                        <h4><i class="fas fa-eye me-2"></i>OCR Text Extraction</h4>
+                        <h4><i class="fas fa-eye me-2"></i>Advanced OCR Text Extraction</h4>
                     </div>
                     <div class="card-body">
-                        <p class="text-muted">Extract text from documents using Optical Character Recognition.</p>
+                        <div class="alert alert-info" role="alert">
+                            <i class="fas fa-info-circle me-2"></i>
+                            <strong>Enhanced OCR Features:</strong> Our advanced Nanonets OCR-s model supports table extraction (HTML), LaTeX equations, watermark detection, signature recognition, and checkbox handling.
+                        </div>
+                        <p class="text-muted">Extract text from documents using advanced Optical Character Recognition with AI-powered document understanding.</p>
 
                         <form id="ocrForm" class="row g-3">
                             <div class="col-12">
@@ -266,12 +270,13 @@
                                 <div class="file-upload border rounded p-4 text-center" id="ocrUpload">
                                     <i class="fas fa-file-alt fa-3x text-warning mb-3"></i>
                                     <p class="mb-0">Click to select or drag & drop files here</p>
+                                    <small class="text-muted">Supports: PDF, JPEG, PNG, TIFF</small>
                                     <input type="file" id="ocrFile" accept="image/*,.pdf" class="d-none">
                                 </div>
                             </div>
                             <div class="col-12">
                                 <button type="submit" class="btn btn-warning">
-                                    <i class="fas fa-
+                                    <i class="fas fa-robot me-2"></i>Extract Text with AI OCR
                                 </button>
                             </div>
                         </form>
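The features the new alert advertises (watermarks, page numbers, signatures) arrive as inline tags in the model output. A minimal sketch of that format and of how app.py's process_tags helper can make the tags display literally in the browser; the HTML-escaping behavior is an assumption, and only the tag names come from the committed prompt:

# Hypothetical sample of Nanonets-OCR-s tagged output (format per the committed prompt).
sample = "Total due: $42.00 <page_number>3/7</page_number> <signature>J. Doe</signature>"

def process_tags(content: str) -> str:
    # Assumed behavior: escape each structural tag so innerHTML shows it verbatim.
    for tag in ("img", "watermark", "page_number", "signature"):
        content = content.replace(f"<{tag}>", f"&lt;{tag}&gt;")
        content = content.replace(f"</{tag}>", f"&lt;/{tag}&gt;")
    return content

print(process_tags(sample))
# Total due: $42.00 &lt;page_number&gt;3/7&lt;/page_number&gt; &lt;signature&gt;J. Doe&lt;/signature&gt;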
static/script.js
CHANGED
@@ -625,7 +625,7 @@ document.getElementById('ocrForm').addEventListener('submit', async (e) => {
     const formData = new FormData();
     formData.append('file', fileInput.files[0]);
 
-    showResult(resultDiv, '<div class="loading"></div> Extracting text...', 'info');
+    showResult(resultDiv, '<div class="loading"></div> Extracting text with advanced OCR...', 'info');
 
     try {
         const response = await authenticatedFetch('/api/ocr', {
@@ -636,9 +636,13 @@ document.getElementById('ocrForm').addEventListener('submit', async (e) => {
         const result = await response.json();
 
         if (response.ok) {
-            showResult(resultDiv, result.text, 'success');
+            const ocrInfo = result.enhanced_features ?
+                `🤖 Processed with ${result.ocr_method} (Enhanced Features: Tables, LaTeX, Watermarks)\n\n` :
+                `📝 Processed with ${result.ocr_method}\n\n`;
+
+            showResult(resultDiv, ocrInfo + result.text, 'success');
         } else {
-            showResult(resultDiv, result.detail, 'error');
+            showResult(resultDiv, result.error || result.detail, 'error');
         }
     } catch (error) {
         showResult(resultDiv, 'OCR failed: ' + error.message, 'error');