Spaces:
Sleeping
Sleeping
from flask import Flask, request, jsonify, send_from_directory | |
from werkzeug.utils import secure_filename | |
from werkzeug.security import generate_password_hash, check_password_hash | |
import pytesseract | |
from PIL import Image | |
import numpy as np | |
import faiss | |
import os | |
import pickle | |
from pdf2image import convert_from_bytes | |
import torch | |
import clip | |
import io | |
import json | |
import uuid | |
from datetime import datetime, timedelta | |
import jwt | |
import sqlite3 | |
import tempfile | |
import base64 | |
from io import BytesIO | |
from transformers import AutoModelForImageTextToText, AutoProcessor, AutoTokenizer, TextIteratorStreamer | |
from threading import Thread | |
import time | |
app = Flask(__name__)

# Security configuration.
# The signing key is read from the SECRET_KEY environment variable when set. The
# placeholder fallback keeps local runs working but must be overridden in any
# real deployment: anyone who knows it can forge access tokens.
SECRET_KEY = os.environ.get("SECRET_KEY", "your-secret-key-change-this-in-production")
# Flask's session key and the JWT key were two copies of the same literal; keep
# them in sync from a single source.
app.config['SECRET_KEY'] = SECRET_KEY
ALGORITHM = "HS256"  # JWT signing algorithm used by create_access_token / verify_token
ACCESS_TOKEN_EXPIRE_MINUTES = 30
# CLIP weights are cached here; the location has to be writable at runtime.
os.environ['CLIP_CACHE'] = '/app/clip_cache'
os.makedirs('/app/clip_cache', exist_ok=True)

# Persistent storage locations (relative to the working directory).
INDEX_PATH = "data/index.faiss"      # FAISS index of category example embeddings
LABELS_PATH = "data/labels.pkl"      # pickled list of labels, one per index vector
DATABASE_PATH = "data/documents.db"  # SQLite database (users and documents tables)
UPLOADS_DIR = "data/uploads"         # stored copies of accepted uploads

for _required_dir in ("data", "static", UPLOADS_DIR):
    os.makedirs(_required_dir, exist_ok=True)
del _required_dir
# Initialize database
def init_db():
    """Create the ``users`` and ``documents`` tables if needed and seed an admin.

    The ``admin`` account is inserted only when no user with that name exists.
    Its password comes from the ``ADMIN_PASSWORD`` environment variable, falling
    back to the historical default ``admin123`` -- set the variable (or change
    the password) before exposing the service.
    """
    conn = sqlite3.connect(DATABASE_PATH)
    try:
        cursor = conn.cursor()
        # Users table
        cursor.execute('''
            CREATE TABLE IF NOT EXISTS users (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                username TEXT UNIQUE NOT NULL,
                password_hash TEXT NOT NULL,
                is_active BOOLEAN DEFAULT TRUE
            )
        ''')
        # Documents table
        cursor.execute('''
            CREATE TABLE IF NOT EXISTS documents (
                id TEXT PRIMARY KEY,
                filename TEXT NOT NULL,
                original_filename TEXT NOT NULL,
                category TEXT NOT NULL,
                similarity REAL NOT NULL,
                ocr_text TEXT,
                upload_date TEXT NOT NULL,
                file_path TEXT NOT NULL
            )
        ''')
        # Insert default admin user if not exists
        cursor.execute('SELECT 1 FROM users WHERE username = ?', ('admin',))
        if cursor.fetchone() is None:
            admin_password = os.environ.get('ADMIN_PASSWORD', 'admin123')
            cursor.execute(
                'INSERT INTO users (username, password_hash) VALUES (?, ?)',
                ('admin', generate_password_hash(admin_password)),
            )
        conn.commit()
    finally:
        # Release the connection even if one of the statements above fails.
        conn.close()


init_db()
# Initialize index and labels.
# labels[i] is the category name of the i-th vector in the index, so the two
# must always have the same length.
EMBEDDING_DIM = 512  # size of the CLIP ViT-B/32 image embeddings stored in the index
index = faiss.IndexFlatL2(EMBEDDING_DIM)
labels = []
if os.path.exists(INDEX_PATH) and os.path.exists(LABELS_PATH):
    try:
        # Load into temporaries so a failure part-way never leaves the loaded
        # index paired with the empty label list (which misaligned later adds).
        loaded_index = faiss.read_index(INDEX_PATH)
        # NOTE(review): pickle.load can execute code from the file; acceptable only
        # because LABELS_PATH is written by this app itself (save_index).
        with open(LABELS_PATH, "rb") as f:
            loaded_labels = pickle.load(f)
        if loaded_index.ntotal != len(loaded_labels):
            raise ValueError(
                f"index has {loaded_index.ntotal} vectors but {len(loaded_labels)} labels"
            )
        index, labels = loaded_index, loaded_labels
        print(f"β Loaded existing index with {len(labels)} labels")
    except Exception as e:
        print(f"β οΈ Failed to load existing index: {e}")
        # Start over from a consistent, empty state.
        index = faiss.IndexFlatL2(EMBEDDING_DIM)
        labels = []
        if os.path.exists(INDEX_PATH):
            os.remove(INDEX_PATH)
        if os.path.exists(LABELS_PATH):
            os.remove(LABELS_PATH)
# Initialize CLIP model with custom cache
def _load_clip_model(target_device):
    """Load CLIP ViT-B/32 onto *target_device*, caching weights in /app/clip_cache.

    Returns ``(model, preprocess)`` on success and ``(None, None)`` on any
    failure; the failure is printed rather than raised.
    """
    try:
        model, transform = clip.load("ViT-B/32", device=target_device, download_root='/app/clip_cache')
        print("β CLIP model loaded successfully")
        return model, transform
    except Exception as err:
        print(f"β Failed to load CLIP model: {err}")
        # Fallback initialization
        return None, None


device = "cuda" if torch.cuda.is_available() else "cpu"
clip_model, preprocess = _load_clip_model(device)
# Initialize Nanonets OCR model
model_path = "nanonets/Nanonets-OCR-s"


def _load_nanonets_ocr(path):
    """Load the Nanonets OCR model, processor and tokenizer from *path*.

    Returns all three together, or ``(None, None, None)`` if any step fails.
    Previously a failure after the model loaded left ``ocr_model`` set while
    the processor/tokenizer stayed ``None``. In that state ocr_document reported
    the Nanonets backend while every call actually fell back to pytesseract.
    """
    try:
        print("Loading Nanonets OCR model...")
        # NOTE: trust_remote_code=True executes Python code shipped with the model repo.
        model = AutoModelForImageTextToText.from_pretrained(
            path,
            torch_dtype="auto",
            device_map="auto",
            trust_remote_code=True
        )
        model.eval()
        processor = AutoProcessor.from_pretrained(path, trust_remote_code=True)
        tokenizer = AutoTokenizer.from_pretrained(path, trust_remote_code=True)
    except Exception as e:
        print(f"β Failed to load Nanonets OCR model: {e}")
        print("π Falling back to pytesseract for OCR")
        return None, None, None
    print("β Nanonets OCR model loaded successfully!")
    return model, processor, tokenizer


ocr_model, ocr_processor, ocr_tokenizer = _load_nanonets_ocr(model_path)
# Helper functions
def save_index():
    """Persist the FAISS index and its parallel label list to disk.

    Each file is first written to a ``.tmp`` sibling and then moved into place
    with ``os.replace``, so an interrupted write never leaves a truncated file
    at INDEX_PATH/LABELS_PATH. The startup loader deletes both files when it
    cannot read them. Errors are printed and swallowed (best effort), as before.
    """
    index_tmp = INDEX_PATH + ".tmp"
    labels_tmp = LABELS_PATH + ".tmp"
    try:
        faiss.write_index(index, index_tmp)
        with open(labels_tmp, "wb") as f:
            pickle.dump(labels, f)
        os.replace(index_tmp, INDEX_PATH)
        os.replace(labels_tmp, LABELS_PATH)
    except Exception as e:
        print(f"β Failed to save index: {e}")
def authenticate_user(username: str, password: str):
    """Check *username*/*password* against the ``users`` table.

    Returns ``{"username": username}`` for an active user whose stored hash
    matches *password*, otherwise ``None``. Missing credentials (``None``, e.g.
    an absent form field) are rejected up front. Previously they reached
    ``check_password_hash`` and could raise.
    """
    if username is None or password is None:
        return None
    conn = sqlite3.connect(DATABASE_PATH)
    try:
        cursor = conn.cursor()
        cursor.execute(
            'SELECT password_hash FROM users WHERE username = ? AND is_active = TRUE',
            (username,),
        )
        result = cursor.fetchone()
    finally:
        conn.close()
    if result and check_password_hash(result[0], password):
        return {"username": username}
    return None
def create_access_token(data: dict):
    """Return a signed JWT containing *data* plus an ``exp`` claim.

    The token expires ACCESS_TOKEN_EXPIRE_MINUTES after issue and is signed with
    SECRET_KEY using ALGORITHM. *data* is copied, so the caller's dict is not
    modified.
    """
    from datetime import timezone  # module-level import only brings in datetime/timedelta

    # Timezone-aware "now": datetime.utcnow() is deprecated since Python 3.12.
    expire = datetime.now(timezone.utc) + timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES)
    claims = {**data, "exp": expire}
    return jwt.encode(claims, SECRET_KEY, algorithm=ALGORITHM)
def verify_token(token: str):
    """Decode *token* and return its ``sub`` (username) claim.

    Returns ``None`` when decoding raises any ``jwt.PyJWTError`` (bad signature,
    expired, malformed) or when the ``sub`` claim is missing or falsy.
    """
    try:
        claims = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
    except jwt.PyJWTError:
        return None
    subject = claims.get("sub")
    return subject or None
def image_from_pdf(pdf_bytes):
    """Render the first page of a PDF (given as bytes) to a PIL image at 200 DPI.

    Returns ``None`` (after printing the reason) if conversion fails or the
    document yields no pages.
    """
    try:
        # Only page 1 is used, so don't rasterise the rest of the document.
        pages = convert_from_bytes(pdf_bytes, dpi=200, first_page=1, last_page=1)
    except Exception as e:
        print(f"β PDF conversion error: {e}")
        return None
    if not pages:
        print("β PDF conversion error: no pages rendered")
        return None
    return pages[0]
# Tag names handled by process_tags(); each has an opening and a closing form.
_NANONETS_TAGS = ("img", "watermark", "page_number", "signature")


def process_tags(content: str) -> str:
    """Process special tags from Nanonets OCR output.

    NOTE(review): every replacement maps a tag onto the identical string, so
    this function currently returns *content* unchanged. The intent may have
    been to escape these tags (e.g. to ``&lt;img&gt;``) -- confirm with the
    consumers of the OCR text before changing it, since the text is stored and
    returned to clients as-is.
    """
    for tag in _NANONETS_TAGS:
        opening, closing = f"<{tag}>", f"</{tag}>"
        content = content.replace(opening, opening).replace(closing, closing)
    return content
def encode_image(image: "Image.Image") -> str:
    """Encode image to base64 for Nanonets OCR.

    The image is JPEG-encoded and returned as a base64 ASCII string. JPEG cannot
    store alpha or palette images (``image.save`` raised for modes such as
    RGBA or P), so anything other than RGB or greyscale is converted to RGB
    first.
    """
    if image.mode not in ("RGB", "L"):
        image = image.convert("RGB")
    buffered = BytesIO()
    image.save(buffered, format="JPEG")
    return base64.b64encode(buffered.getvalue()).decode("utf-8")
def nanonets_ocr_extract(image):
    """Extract text from a PIL image with the Nanonets OCR model.

    The image is converted to RGB and resized to 2048x2048 (the aspect ratio is
    not preserved). It is then sent to the model with an extraction prompt and
    decoded greedily (``do_sample=False``) for up to 4096 new tokens. The result
    goes through ``process_tags``.

    Falls back to ``extract_text_pytesseract`` on the *original* image when the
    model, processor or tokenizer is not loaded, or when any step raises.

    Returns the stripped text, or the "No text detected" marker string when the
    model produced only whitespace.
    """
    source_image = image  # untouched copy for the pytesseract fallback
    try:
        if ocr_model is None or ocr_processor is None or ocr_tokenizer is None:
            # Fallback to pytesseract
            return extract_text_pytesseract(source_image)
        if image.mode != 'RGB':
            image = image.convert('RGB')
        # Resize image for model processing (stretches non-square pages)
        image = image.resize((2048, 2048))
        # Prompt for OCR extraction. The check-box glyphs are U+2610 (☐) and
        # U+2611 (☑); a mis-decoded copy of this file had garbled them.
        user_prompt = """Extract the text from the above document as if you were reading it naturally. Return the tables in html format. Return the equations in LaTeX representation. Watermarks should be wrapped in brackets. Ex: <watermark>OFFICIAL COPY</watermark>. Page numbers should be wrapped in brackets. Ex: <page_number>14</page_number> or <page_number>9/22</page_number>. Prefer using ☐ and ☑ for check boxes."""
        # Format messages for the model
        formatted_messages = [
            {"role": "system", "content": "You are a helpful assistant."},
            {"role": "user", "content": [
                {"type": "image", "image": image},
                {"type": "text", "text": user_prompt},
            ]},
        ]
        # Apply chat template
        text = ocr_processor.apply_chat_template(
            formatted_messages,
            tokenize=False,
            add_generation_prompt=True
        )
        # Process inputs
        inputs = ocr_processor(
            text=[text],
            images=[image],
            padding=True,
            return_tensors="pt"
        )
        # Move tensor inputs to the model's device
        inputs = {k: v.to(ocr_model.device) if hasattr(v, 'to') else v for k, v in inputs.items()}
        with torch.no_grad():
            generated_ids = ocr_model.generate(
                **inputs,
                max_new_tokens=4096,
                do_sample=False,
                pad_token_id=ocr_tokenizer.eos_token_id,
            )
        # The output starts with the prompt tokens; decode only what was generated.
        prompt_length = inputs['input_ids'].shape[1]
        generated_text = ocr_tokenizer.decode(
            generated_ids[0][prompt_length:],
            skip_special_tokens=True
        )
        processed_text = process_tags(generated_text).strip()
        return processed_text if processed_text else "β No text detected"
    except Exception as e:
        print(f"β Nanonets OCR error: {e}")
        # Fallback to pytesseract
        return extract_text_pytesseract(source_image)
def extract_text_pytesseract(image):
    """Fallback OCR of *image* with Tesseract via pytesseract.

    The image is converted to RGB if needed and recognised with the config
    ``--oem 3 --psm 6``. Returns the stripped text, a "No text detected"
    marker when the result is blank, or an "OCR error" string if anything
    raises (errors are returned, not propagated).
    """
    try:
        rgb_image = image if image.mode == 'RGB' else image.convert('RGB')
        recognised = pytesseract.image_to_string(rgb_image, config=r'--oem 3 --psm 6').strip()
    except Exception as exc:
        return f"β OCR error: {str(exc)}"
    return recognised or "β No text detected"
def extract_text(image):
    """Main OCR entry point: use Nanonets when its model loaded, else pytesseract."""
    ocr_backend = extract_text_pytesseract if ocr_model is None else nanonets_ocr_extract
    return ocr_backend(image)
def get_clip_embedding(image):
    """Return the L2-normalised CLIP embedding of *image* as a 1-D float32 array.

    Returns ``None`` if the CLIP model is not loaded or any step raises (the
    error is printed).
    """
    try:
        if clip_model is None:
            return None
        if image.mode != 'RGB':
            image = image.convert('RGB')
        # Batch of one image on the model's device.
        image_input = preprocess(image).unsqueeze(0).to(device)
        with torch.no_grad():
            image_features = clip_model.encode_image(image_input)
        # Unit length, so L2 distances in the FAISS index reflect cosine similarity.
        image_features = image_features / image_features.norm(dim=-1, keepdim=True)
        # CLIP may keep fp16 weights on CUDA (it only forces float32 on CPU), and
        # FAISS expects float32 vectors, so cast explicitly.
        return image_features[0].float().cpu().numpy()
    except Exception as e:
        print(f"β CLIP embedding error: {e}")
        return None
def save_uploaded_file(file_content: bytes, filename: str, upload_dir=None) -> str:
    """Write *file_content* to disk under a fresh UUID name and return that name.

    Only the extension of the client-supplied *filename* is kept, and only when
    it is a plain ASCII alphanumeric suffix such as ``.pdf``. Anything else --
    including a missing (``None``/empty) filename -- yields a name with no
    extension. This keeps untrusted text out of the stored path.

    Args:
        file_content: raw bytes to store.
        filename: original client filename; may be ``None``.
        upload_dir: target directory; defaults to ``UPLOADS_DIR``.

    Returns:
        The generated file name (not the full path).
    """
    target_dir = UPLOADS_DIR if upload_dir is None else upload_dir
    file_extension = os.path.splitext(filename or "")[1]
    suffix_body = file_extension[1:]
    if not (suffix_body.isascii() and suffix_body.isalnum()):
        file_extension = ""
    saved_filename = f"{uuid.uuid4()}{file_extension}"
    with open(os.path.join(target_dir, saved_filename), 'wb') as f:
        f.write(file_content)
    return saved_filename
# Routes
# NOTE(review): no @app.route decorators are visible on the view functions in
# this copy of the file; confirm they are registered with the Flask app.
def dashboard():
    """Serve the single-page front end, ``static/index.html``."""
    static_root, page = 'static', 'index.html'
    return send_from_directory(static_root, page)
def static_files(filename):
    """Serve *filename* from the ``static`` directory.

    ``send_from_directory`` refuses paths that would escape that directory.
    """
    static_root = 'static'
    return send_from_directory(static_root, filename)
def login():
    """Exchange form-encoded ``username``/``password`` for a JWT bearer token.

    Responds 401 with a ``detail`` message on bad credentials; otherwise returns
    the access token, its type and the username.
    """
    form = request.form
    user = authenticate_user(form.get("username"), form.get("password"))
    if not user:
        return jsonify({"detail": "Incorrect username or password"}), 401
    username = user["username"]
    token = create_access_token(data={"sub": username})
    return jsonify({"access_token": token, "token_type": "bearer", "username": username})
def upload_category():
    """Add a labelled example image to the category similarity index.

    Requires a Bearer token, a ``label`` form field and a ``file`` upload. The
    file is a PDF (judged by its declared content type; page 1 is used) or any
    image Pillow can open. Its CLIP embedding is appended to the FAISS index
    and the label to ``labels`` at the same position, then both are saved.
    """
    # Verify token
    auth_header = request.headers.get('Authorization')
    if not auth_header or not auth_header.startswith('Bearer '):
        return jsonify({"error": "Missing or invalid token"}), 401
    token = auth_header.split(' ')[1]
    username = verify_token(token)
    if not username:
        return jsonify({"error": "Invalid token"}), 401
    try:
        # Strip before validating so a whitespace-only label is rejected instead
        # of being stored as an empty category name.
        label = (request.form.get("label") or "").strip()
        file = request.files.get("file")
        if not label or not file:
            return jsonify({"error": "Missing label or file"}), 400
        file_content = file.read()
        if file.content_type and file.content_type.startswith('application/pdf'):
            image = image_from_pdf(file_content)
        else:
            image = Image.open(io.BytesIO(file_content))
        if image is None:
            return jsonify({"error": "Failed to process image"}), 400
        embedding = get_clip_embedding(image)
        if embedding is None:
            return jsonify({"error": "Failed to generate embedding"}), 400
        # labels[i] names the i-th index vector, so always add to both together.
        # FAISS requires float32 input.
        index.add(np.asarray([embedding], dtype=np.float32))
        labels.append(label)
        save_index()
        return jsonify({"message": f"β Added category '{label}' (Total: {len(labels)} categories)", "status": "success"})
    except Exception as e:
        return jsonify({"error": str(e)}), 500
def _save_classified_document(file_content, original_filename, image, category, similarity):
    """Store an accepted upload, OCR it, and insert a ``documents`` row.

    Returns ``(document_id, ocr_text)``.
    """
    saved_filename = save_uploaded_file(file_content, original_filename)
    ocr_text = extract_text(image)  # Nanonets OCR when loaded, else pytesseract
    document_id = str(uuid.uuid4())
    conn = sqlite3.connect(DATABASE_PATH)
    try:
        conn.execute(
            '''
            INSERT INTO documents (id, filename, original_filename, category, similarity, ocr_text, upload_date, file_path)
            VALUES (?, ?, ?, ?, ?, ?, ?, ?)
            ''',
            (document_id, saved_filename, original_filename, category, round(similarity, 3),
             ocr_text, datetime.now().isoformat(), os.path.join(UPLOADS_DIR, saved_filename)),
        )
        conn.commit()
    finally:
        conn.close()
    return document_id, ocr_text


def classify_document():
    """Classify an uploaded document against the stored category examples.

    The upload's CLIP embedding is searched in the index (up to 3 neighbours).
    The reported ``similarity`` is ``1 - d``, where ``d`` is the distance
    returned by ``IndexFlatL2`` (squared L2). For the unit-length embeddings
    produced by get_clip_embedding this equals ``2*cos - 1``, not the cosine
    itself; the 0.35 threshold is on that scale.

    At or above the threshold the file is saved, OCR'd and recorded in SQLite
    ("success"). Below it a "low_confidence" response is returned and nothing
    is stored.
    """
    # Verify token
    auth_header = request.headers.get('Authorization')
    if not auth_header or not auth_header.startswith('Bearer '):
        return jsonify({"error": "Missing or invalid token"}), 401
    token = auth_header.split(' ')[1]
    username = verify_token(token)
    if not username:
        return jsonify({"error": "Invalid token"}), 401
    try:
        if len(labels) == 0:
            return jsonify({"error": "No categories in database. Please add some first."}), 400
        file = request.files.get("file")
        if not file:
            return jsonify({"error": "Missing file"}), 400
        file_content = file.read()
        if file.content_type and file.content_type.startswith('application/pdf'):
            image = image_from_pdf(file_content)
        else:
            image = Image.open(io.BytesIO(file_content))
        if image is None:
            return jsonify({"error": "Failed to process image."}), 400
        embedding = get_clip_embedding(image)
        if embedding is None:
            return jsonify({"error": "Failed to generate embedding"}), 400

        k = min(3, len(labels))
        D, I = index.search(np.asarray([embedding], dtype=np.float32), k=k)
        # FAISS pads missing neighbours with -1 (which would silently index
        # labels[-1]); also drop ids beyond the label list. Values are converted
        # to Python int/float for JSON serialization.
        neighbours = [
            (int(idx), float(1 - dist))
            for idx, dist in zip(I[0], D[0])
            if 0 <= idx < len(labels)
        ]
        if not neighbours:
            return jsonify({"error": "Document not recognized"}), 400

        best_index, similarity = neighbours[0]
        best_match = labels[best_index]
        confidence_threshold = 0.35
        matches = [{"category": labels[idx], "similarity": round(sim, 3)} for idx, sim in neighbours]

        if similarity < confidence_threshold:
            return jsonify({
                "status": "low_confidence",
                "category": best_match,
                "similarity": round(similarity, 3),
                "confidence": "low",
                "matches": matches,
                "document_saved": False
            })

        document_id, ocr_text = _save_classified_document(
            file_content, file.filename, image, best_match, similarity
        )
        return jsonify({
            "status": "success",
            "category": best_match,
            "similarity": round(similarity, 3),
            "confidence": "high",
            "matches": matches,
            "document_saved": True,
            "document_id": document_id,
            "ocr_preview": (ocr_text[:200] + "...") if len(ocr_text) > 200 else ocr_text
        })
    except Exception as e:
        print(f"Classification error: {e}")
        return jsonify({"error": str(e)}), 500
def get_categories():
    """List the distinct category labels and the number of examples of each.

    ``categories`` is in first-seen order. It used ``set`` order before, which
    is arbitrary and can change between runs.
    """
    from collections import Counter

    # Verify token
    auth_header = request.headers.get('Authorization')
    if not auth_header or not auth_header.startswith('Bearer '):
        return jsonify({"error": "Missing or invalid token"}), 401
    token = auth_header.split(' ')[1]
    username = verify_token(token)
    if not username:
        return jsonify({"error": "Invalid token"}), 401
    # Counter keeps first-insertion order, so its keys are the de-duplicated labels.
    category_counts = Counter(labels)
    return jsonify({"categories": list(category_counts), "counts": dict(category_counts)})
def get_documents_by_category(category):
    """Return all stored documents of *category*, ordered by upload_date descending."""
    # Verify token
    auth_header = request.headers.get('Authorization')
    if not auth_header or not auth_header.startswith('Bearer '):
        return jsonify({"error": "Missing or invalid token"}), 401
    token = auth_header.split(' ')[1]
    username = verify_token(token)
    if not username:
        return jsonify({"error": "Invalid token"}), 401
    columns = ("id", "filename", "original_filename", "category",
               "similarity", "ocr_text", "upload_date", "file_path")
    conn = sqlite3.connect(DATABASE_PATH)
    try:
        # Name the columns explicitly so the mapping below does not depend on the
        # table's physical column order (as SELECT * with positional indexing did).
        rows = conn.execute(
            'SELECT id, filename, original_filename, category, similarity, ocr_text, upload_date, file_path '
            'FROM documents WHERE category = ? ORDER BY upload_date DESC',
            (category,),
        ).fetchall()
    finally:
        conn.close()
    documents = [dict(zip(columns, row)) for row in rows]
    return jsonify({"documents": documents, "count": len(documents)})
def delete_document(document_id):
    """Delete a stored document: its database row and its uploaded file.

    The row is deleted and committed first, then the file is removed. A
    database failure therefore no longer leaves a row pointing at a file that
    is already gone. If removing the file fails, the deletion still succeeds
    and the problem is printed (the file is orphaned).
    """
    # Verify token
    auth_header = request.headers.get('Authorization')
    if not auth_header or not auth_header.startswith('Bearer '):
        return jsonify({"error": "Missing or invalid token"}), 401
    token = auth_header.split(' ')[1]
    username = verify_token(token)
    if not username:
        return jsonify({"error": "Invalid token"}), 401
    try:
        conn = sqlite3.connect(DATABASE_PATH)
        try:
            # Get document info first
            row = conn.execute('SELECT file_path FROM documents WHERE id = ?', (document_id,)).fetchone()
            if not row:
                return jsonify({"error": "Document not found"}), 404
            conn.execute('DELETE FROM documents WHERE id = ?', (document_id,))
            conn.commit()
        finally:
            conn.close()
        file_path = row[0]
        # Delete physical file (best effort: the record is already gone)
        if file_path and os.path.exists(file_path):
            try:
                os.remove(file_path)
            except OSError as err:
                print(f"Failed to remove file {file_path}: {err}")
        return jsonify({"message": "Document deleted successfully", "status": "success"})
    except Exception as e:
        return jsonify({"error": str(e)}), 500
def ocr_document():
    """OCR an uploaded image (or the first page of a PDF) and return the text.

    The response also reports which OCR backend is loaded (Nanonets or
    pytesseract).
    """
    # Verify token
    auth_header = request.headers.get('Authorization')
    if not auth_header or not auth_header.startswith('Bearer '):
        return jsonify({"error": "Missing or invalid token"}), 401
    if not verify_token(auth_header.split(' ')[1]):
        return jsonify({"error": "Invalid token"}), 401
    try:
        upload = request.files.get("file")
        if not upload:
            return jsonify({"error": "Missing file"}), 400
        raw_bytes = upload.read()
        is_pdf = bool(upload.content_type) and upload.content_type.startswith('application/pdf')
        image = image_from_pdf(raw_bytes) if is_pdf else Image.open(io.BytesIO(raw_bytes))
        if image is None:
            return jsonify({"error": "Failed to process image"}), 400
        text = extract_text(image)
        nanonets_loaded = ocr_model is not None
        return jsonify({
            "text": text,
            "status": "success",
            "ocr_method": "Nanonets OCR-s" if nanonets_loaded else "Pytesseract",
            "enhanced_features": nanonets_loaded,
        })
    except Exception as e:
        return jsonify({"error": str(e)}), 500
def get_stats():
    """Return dashboard statistics.

    ``total_categories`` counts distinct labels in the similarity index;
    ``total_documents`` and ``category_distribution`` come from the
    ``documents`` table.
    """
    # Verify token
    auth_header = request.headers.get('Authorization')
    if not auth_header or not auth_header.startswith('Bearer '):
        return jsonify({"error": "Missing or invalid token"}), 401
    token = auth_header.split(' ')[1]
    username = verify_token(token)
    if not username:
        return jsonify({"error": "Invalid token"}), 401
    conn = sqlite3.connect(DATABASE_PATH)
    try:
        category_stats = dict(
            conn.execute('SELECT category, COUNT(*) FROM documents GROUP BY category').fetchall()
        )
        total_documents = conn.execute('SELECT COUNT(*) FROM documents').fetchone()[0]
    finally:
        # Close the connection even if a query fails.
        conn.close()
    return jsonify({
        "total_categories": len(set(labels)),
        "total_documents": total_documents,
        "category_distribution": category_stats
    })
def get_document_preview(document_id):
    """Send the stored file for *document_id*, e.g. for an in-page preview.

    The JWT is taken from an ``Authorization: Bearer`` header or, failing that,
    from a ``token`` query parameter. The comment in the original code says this
    fallback is for image requests.
    NOTE(review): tokens in query strings can end up in server/proxy logs.
    """
    auth_header = request.headers.get('Authorization')
    if auth_header and auth_header.startswith('Bearer '):
        token = auth_header.split(' ')[1]
    else:
        # For image requests, try to get token from query params as fallback
        token = request.args.get('token')
        if not token:
            return jsonify({"error": "Missing or invalid token"}), 401
    if not verify_token(token):
        return jsonify({"error": "Invalid token"}), 401
    try:
        conn = sqlite3.connect(DATABASE_PATH)
        row = conn.execute('SELECT file_path FROM documents WHERE id = ?', (document_id,)).fetchone()
        conn.close()
        if not row:
            return jsonify({"error": "Document not found"}), 404
        stored_path = row[0]
        if not os.path.exists(stored_path):
            return jsonify({"error": "File not found"}), 404
        directory, name = os.path.split(stored_path)
        return send_from_directory(directory, name)
    except Exception as e:
        return jsonify({"error": str(e)}), 500
if __name__ == "__main__":
    # The Werkzeug debugger lets anyone who reaches it run arbitrary code, so it
    # must not be on by default for a server bound to all interfaces. Opt in for
    # local development with FLASK_DEBUG=1 (or "true"/"yes").
    debug_mode = os.environ.get("FLASK_DEBUG", "").lower() in ("1", "true", "yes")
    app.run(host="0.0.0.0", port=7860, debug=debug_mode)