"""Document classification and OCR service.

Reference images ("categories") are embedded with CLIP and stored in a FAISS
index; uploaded documents are classified by nearest-neighbour search against
that index. Confidently classified documents are saved to disk, OCR'd with the
Nanonets OCR-s model when it loaded (pytesseract otherwise) and recorded in
SQLite. Every /api route except /api/login requires a JWT bearer token.
"""

from flask import Flask, request, jsonify, send_from_directory
from werkzeug.utils import secure_filename
from werkzeug.security import generate_password_hash, check_password_hash
import pytesseract
from PIL import Image
import numpy as np
import faiss
import os
import pickle
from pdf2image import convert_from_bytes
import torch
import clip
import io
import json
import uuid
from datetime import datetime, timedelta, timezone
import jwt
import sqlite3
import tempfile
import base64
from io import BytesIO
from transformers import AutoModelForImageTextToText, AutoProcessor, AutoTokenizer, TextIteratorStreamer
from threading import Thread, Lock
import time
from collections import Counter
from contextlib import closing

app = Flask(__name__)

# Security configuration. The literal fallbacks are development defaults only;
# set SECRET_KEY / ADMIN_PASSWORD in the environment for any real deployment.
SECRET_KEY = os.environ.get("SECRET_KEY", "your-secret-key-change-this-in-production")
app.config['SECRET_KEY'] = SECRET_KEY
ALGORITHM = "HS256"
ACCESS_TOKEN_EXPIRE_MINUTES = 30
DEFAULT_ADMIN_PASSWORD = os.environ.get("ADMIN_PASSWORD", "admin123")

# Minimum similarity score (as computed in classify_document) for a document
# to be saved and OCR'd.
CONFIDENCE_THRESHOLD = 0.35

# Set CLIP cache to writable directory
os.environ['CLIP_CACHE'] = '/app/clip_cache'
os.makedirs('/app/clip_cache', exist_ok=True)

# Directories
INDEX_PATH = "data/index.faiss"
LABELS_PATH = "data/labels.pkl"
DATABASE_PATH = "data/documents.db"
UPLOADS_DIR = "data/uploads"
os.makedirs("data", exist_ok=True)
os.makedirs("static", exist_ok=True)
os.makedirs(UPLOADS_DIR, exist_ok=True)


def _connect():
    """Open a SQLite connection that is closed when the ``with`` block exits.

    ``sqlite3.Connection`` used directly as a context manager only commits or
    rolls back and never closes; wrapping it in ``closing`` guarantees the
    close even when a query raises.
    """
    return closing(sqlite3.connect(DATABASE_PATH))


def init_db():
    """Create the users/documents tables and seed the ``admin`` account if missing."""
    with _connect() as conn:
        cursor = conn.cursor()
        # Users table
        cursor.execute('''
            CREATE TABLE IF NOT EXISTS users (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                username TEXT UNIQUE NOT NULL,
                password_hash TEXT NOT NULL,
                is_active BOOLEAN DEFAULT TRUE
            )
        ''')
        # Documents table
        cursor.execute('''
            CREATE TABLE IF NOT EXISTS documents (
                id TEXT PRIMARY KEY,
                filename TEXT NOT NULL,
                original_filename TEXT NOT NULL,
                category TEXT NOT NULL,
                similarity REAL NOT NULL,
                ocr_text TEXT,
                upload_date TEXT NOT NULL,
                file_path TEXT NOT NULL
            )
        ''')
        # Insert default admin user if not exists
        cursor.execute('SELECT 1 FROM users WHERE username = ?', ('admin',))
        if not cursor.fetchone():
            admin_hash = generate_password_hash(DEFAULT_ADMIN_PASSWORD)
            cursor.execute('INSERT INTO users (username, password_hash) VALUES (?, ?)',
                           ('admin', admin_hash))
        conn.commit()


init_db()

# Initialize index and labels. labels[i] is the category name of vector i in
# the index; request handlers mutate both together under _index_lock because
# the Flask server may serve requests on several threads.
_index_lock = Lock()
index = faiss.IndexFlatL2(512)
labels = []
if os.path.exists(INDEX_PATH) and os.path.exists(LABELS_PATH):
    try:
        loaded_index = faiss.read_index(INDEX_PATH)
        with open(LABELS_PATH, "rb") as f:
            # NOTE: pickle can execute code while loading; this is acceptable
            # only because the file is produced by save_index() itself.
            loaded_labels = pickle.load(f)
        # Assign both at once so a failure part-way never pairs a populated
        # index with an empty label list.
        index, labels = loaded_index, loaded_labels
        print(f"✅ Loaded existing index with {len(labels)} labels")
    except Exception as e:
        print(f"⚠️ Failed to load existing index: {e}")
        if os.path.exists(INDEX_PATH):
            os.remove(INDEX_PATH)
        if os.path.exists(LABELS_PATH):
            os.remove(LABELS_PATH)

# Initialize CLIP model with custom cache
device = "cuda" if torch.cuda.is_available() else "cpu"
try:
    clip_model, preprocess = clip.load("ViT-B/32", device=device, download_root='/app/clip_cache')
    print("✅ CLIP model loaded successfully")
except Exception as e:
    print(f"❌ Failed to load CLIP model: {e}")
    # Fallback: routes that need embeddings will report a failure instead.
    clip_model = None
    preprocess = None

# Initialize Nanonets OCR model (optional; pytesseract is used when it fails)
ocr_model = None
ocr_processor = None
ocr_tokenizer = None
try:
    model_path = "nanonets/Nanonets-OCR-s"
    print("Loading Nanonets OCR model...")
    ocr_model = AutoModelForImageTextToText.from_pretrained(
        model_path,
        torch_dtype="auto",
        device_map="auto",
        trust_remote_code=True
    )
    ocr_model.eval()
    ocr_processor = AutoProcessor.from_pretrained(model_path, trust_remote_code=True)
    ocr_tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
    print("✅ Nanonets OCR model loaded successfully!")
except Exception as e:
    print(f"❌ Failed to load Nanonets OCR model: {e}")
    print("📝 Falling back to pytesseract for OCR")


# Helper functions
def save_index():
    """Persist the FAISS index and label list to disk; errors are logged, not raised.

    Callers that have just mutated ``index``/``labels`` should hold ``_index_lock``.
    """
    try:
        faiss.write_index(index, INDEX_PATH)
        with open(LABELS_PATH, "wb") as f:
            pickle.dump(labels, f)
    except Exception as e:
        print(f"❌ Failed to save index: {e}")


def authenticate_user(username: str, password: str):
    """Return ``{"username": username}`` for valid credentials of an active user, else None."""
    # check_password_hash cannot handle a missing password, so reject early.
    if not username or not password:
        return None
    with _connect() as conn:
        row = conn.execute(
            'SELECT password_hash FROM users WHERE username = ? AND is_active = TRUE',
            (username,),
        ).fetchone()
    if row and check_password_hash(row[0], password):
        return {"username": username}
    return None


def create_access_token(data: dict):
    """Return a signed JWT carrying *data* plus an ``exp`` claim
    ACCESS_TOKEN_EXPIRE_MINUTES from now."""
    expire = datetime.now(timezone.utc) + timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES)
    to_encode = data.copy()
    to_encode.update({"exp": expire})
    return jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM)


def verify_token(token: str):
    """Return the ``sub`` claim of a valid, unexpired token, or None."""
    try:
        payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
        username = payload.get("sub")
        return username if username else None
    except jwt.PyJWTError:
        return None


def image_from_pdf(pdf_bytes):
    """Rasterise the first page of a PDF at 200 dpi; return None on any failure."""
    try:
        # Only page 1 is used, so don't spend time rendering the rest.
        images = convert_from_bytes(pdf_bytes, dpi=200, first_page=1, last_page=1)
        return images[0]
    except Exception as e:
        print(f"❌ PDF conversion error: {e}")
        return None


# Pseudo-HTML tags Nanonets OCR-s may emit around special regions.
_NANONETS_TAGS = ("img", "watermark", "page_number", "signature")


def process_tags(content: str) -> str:
    """HTML-escape the special tags in Nanonets OCR output.

    ``<img>``, ``<watermark>``, ``<page_number>`` and ``<signature>`` (and their
    closing forms) become ``&lt;tag&gt;``. A client that renders the text as
    HTML then shows them literally instead of treating them as markup. Other
    text is returned unchanged.
    """
    for tag in _NANONETS_TAGS:
        content = content.replace(f"<{tag}>", f"&lt;{tag}&gt;")
        content = content.replace(f"</{tag}>", f"&lt;/{tag}&gt;")
    return content


def encode_image(image: Image.Image) -> str:
    """Return *image* encoded as a base64 JPEG string (not used by the routes below)."""
    if image.mode != 'RGB':
        # JPEG cannot store alpha or palette modes (e.g. RGBA, P).
        image = image.convert('RGB')
    buffered = BytesIO()
    image.save(buffered, format="JPEG")
    return base64.b64encode(buffered.getvalue()).decode("utf-8")


def nanonets_ocr_extract(image):
    """Extract text with the Nanonets OCR-s model, falling back to pytesseract.

    Falls back when the model components are not loaded or generation raises.
    Returns the recognised text, or "❓ No text detected" when it is empty.
    """
    try:
        if ocr_model is None or ocr_processor is None or ocr_tokenizer is None:
            return extract_text_pytesseract(image)

        if image.mode != 'RGB':
            image = image.convert('RGB')

        # NOTE(review): fixed 2048x2048 resize ignores the aspect ratio.
        image = image.resize((2048, 2048))

        user_prompt = (
            "Extract the text from the above document as if you were reading it naturally. "
            "Return the tables in html format. Return the equations in LaTeX representation. "
            "Watermarks should be wrapped in brackets. Ex: <watermark>OFFICIAL COPY</watermark>. "
            "Page numbers should be wrapped in brackets. Ex: <page_number>14</page_number> or "
            "<page_number>9/22</page_number>. Prefer using ☐ and ☑ for check boxes."
        )

        formatted_messages = [
            {"role": "system", "content": "You are a helpful assistant."},
            {"role": "user", "content": [
                {"type": "image", "image": image},
                {"type": "text", "text": user_prompt},
            ]},
        ]

        text = ocr_processor.apply_chat_template(
            formatted_messages,
            tokenize=False,
            add_generation_prompt=True
        )

        inputs = ocr_processor(
            text=[text],
            images=[image],
            padding=True,
            return_tensors="pt"
        )

        # Move tensors to wherever device_map="auto" placed the model.
        inputs = {k: v.to(ocr_model.device) if hasattr(v, 'to') else v for k, v in inputs.items()}

        with torch.no_grad():
            generated_ids = ocr_model.generate(
                **inputs,
                max_new_tokens=4096,
                do_sample=False,
                pad_token_id=ocr_tokenizer.eos_token_id,
            )

        # Drop the prompt tokens; decode only what the model generated.
        generated_text = ocr_tokenizer.decode(
            generated_ids[0][inputs['input_ids'].shape[1]:],
            skip_special_tokens=True
        )

        processed_text = process_tags(generated_text).strip()
        return processed_text if processed_text else "❓ No text detected"

    except Exception as e:
        print(f"❌ Nanonets OCR error: {e}")
        return extract_text_pytesseract(image)


def extract_text_pytesseract(image):
    """Fallback OCR using pytesseract; errors are returned as a message string."""
    try:
        if image.mode != 'RGB':
            image = image.convert('RGB')
        custom_config = r'--oem 3 --psm 6'
        text = pytesseract.image_to_string(image, config=custom_config)
        return text.strip() if text.strip() else "❓ No text detected"
    except Exception as e:
        return f"❌ OCR error: {str(e)}"


def extract_text(image):
    """Main OCR entry point: Nanonets when its model loaded, pytesseract otherwise."""
    if ocr_model is not None:
        return nanonets_ocr_extract(image)
    return extract_text_pytesseract(image)


def get_clip_embedding(image):
    """Return the unit-normalised CLIP image embedding (float32, shape (512,)), or None."""
    try:
        if clip_model is None:
            return None
        if image.mode != 'RGB':
            image = image.convert('RGB')
        image_input = preprocess(image).unsqueeze(0).to(device)
        with torch.no_grad():
            image_features = clip_model.encode_image(image_input)
            image_features = image_features / image_features.norm(dim=-1, keepdim=True)
        # On CUDA CLIP runs in fp16; FAISS works in float32.
        return image_features.float().cpu().numpy()[0]
    except Exception as e:
        print(f"❌ CLIP embedding error: {e}")
        return None


def save_uploaded_file(file_content: bytes, filename: str) -> str:
    """Write *file_content* to UPLOADS_DIR under a random name, keeping the extension.

    Returns the generated file name (not the full path).
    """
    file_id = str(uuid.uuid4())
    # os.path.splitext only looks at the final path component, so the
    # extension cannot smuggle in directory separators.
    file_extension = os.path.splitext(filename or "")[1]
    saved_filename = f"{file_id}{file_extension}"
    file_path = os.path.join(UPLOADS_DIR, saved_filename)
    with open(file_path, 'wb') as f:
        f.write(file_content)
    return saved_filename


def _load_upload_image(file_content: bytes, content_type):
    """Decode an uploaded file into a PIL image, or return None if it can't be decoded.

    Files whose declared content type is application/pdf are rasterised to
    their first page.
    """
    if content_type and content_type.startswith('application/pdf'):
        return image_from_pdf(file_content)
    try:
        return Image.open(io.BytesIO(file_content))
    except OSError:  # PIL.UnidentifiedImageError is a subclass of OSError
        return None


def _check_auth(allow_query_token: bool = False):
    """Validate the bearer token of the current request.

    Returns None when authenticated, otherwise a ``(response, 401)`` tuple the
    route should return as-is. With *allow_query_token*, a ``?token=`` query
    parameter is accepted when no Bearer header is sent (used for <img> tags,
    which cannot set headers).
    """
    auth_header = request.headers.get('Authorization')
    if auth_header and auth_header.startswith('Bearer '):
        token = auth_header.split(' ')[1]
    else:
        token = request.args.get('token') if allow_query_token else None
        if not token:
            return jsonify({"error": "Missing or invalid token"}), 401
    if not verify_token(token):
        return jsonify({"error": "Invalid token"}), 401
    return None


# Routes
@app.route("/")
def dashboard():
    """Serve the single-page dashboard."""
    return send_from_directory('static', 'index.html')


@app.route("/static/<path:filename>")
def static_files(filename):
    """Serve files from the static directory."""
    return send_from_directory('static', filename)


@app.route("/api/login", methods=["POST"])
def login():
    """Exchange form-encoded username/password for a bearer token."""
    username = request.form.get("username")
    password = request.form.get("password")
    user = authenticate_user(username, password)
    if not user:
        return jsonify({"detail": "Incorrect username or password"}), 401
    access_token = create_access_token(data={"sub": user["username"]})
    return jsonify({"access_token": access_token, "token_type": "bearer", "username": user["username"]})


@app.route("/api/upload-category", methods=["POST"])
def upload_category():
    """Add a reference image (or first PDF page) to the index under a category label."""
    auth_error = _check_auth()
    if auth_error:
        return auth_error

    try:
        label = request.form.get("label")
        file = request.files.get("file")
        # A whitespace-only label would otherwise be stored as "".
        if not label or not label.strip() or not file:
            return jsonify({"error": "Missing label or file"}), 400

        file_content = file.read()
        image = _load_upload_image(file_content, file.content_type)
        if image is None:
            return jsonify({"error": "Failed to process image"}), 400

        embedding = get_clip_embedding(image)
        if embedding is None:
            return jsonify({"error": "Failed to generate embedding"}), 400

        with _index_lock:
            # index and labels must grow in lockstep: vector i <-> labels[i].
            index.add(np.array([embedding], dtype=np.float32))
            labels.append(label.strip())
            save_index()
            total = len(labels)

        return jsonify({"message": f"✅ Added category '{label}' (Total: {total} categories)",
                        "status": "success"})
    except Exception as e:
        return jsonify({"error": str(e)}), 500


@app.route("/api/classify-document", methods=["POST"])
def classify_document():
    """Classify an uploaded document against the stored categories.

    Responds with the best category, its similarity and up to three matches.
    When the best similarity reaches CONFIDENCE_THRESHOLD the file is also
    saved, OCR'd and recorded in the documents table.
    """
    auth_error = _check_auth()
    if auth_error:
        return auth_error

    try:
        if len(labels) == 0:
            return jsonify({"error": "No categories in database. Please add some first."}), 400

        file = request.files.get("file")
        if not file:
            return jsonify({"error": "Missing file"}), 400

        file_content = file.read()
        image = _load_upload_image(file_content, file.content_type)
        if image is None:
            return jsonify({"error": "Failed to process image."}), 400

        embedding = get_clip_embedding(image)
        if embedding is None:
            return jsonify({"error": "Failed to generate embedding"}), 400

        with _index_lock:
            # Snapshot labels with the search so result ids map to the right names.
            known_labels = list(labels)
            k = min(3, len(known_labels))
            distances, ids = index.search(np.array([embedding], dtype=np.float32), k)

        # IndexFlatL2 returns *squared* L2 distances. The embeddings are unit
        # vectors, so 1 - d == 2*cos - 1; the threshold of 0.35 thus means a
        # cosine similarity of at least 0.675.
        matches = [
            {"category": known_labels[idx], "similarity": round(float(1 - dist), 3)}
            for dist, idx in zip(distances[0], ids[0])
            if 0 <= idx < len(known_labels)
        ]

        best_idx = ids[0][0]
        # FAISS reports missing neighbours as id -1, which must not index labels.
        if not 0 <= best_idx < len(known_labels):
            return jsonify({"error": "Document not recognized"}), 400

        similarity = float(1 - distances[0][0])
        best_match = known_labels[best_idx]

        if similarity < CONFIDENCE_THRESHOLD:
            return jsonify({
                "status": "low_confidence",
                "category": best_match,
                "similarity": round(similarity, 3),
                "confidence": "low",
                "matches": matches,
                "document_saved": False
            })

        ocr_text = extract_text(image)
        original_filename = file.filename or ""
        saved_filename = save_uploaded_file(file_content, original_filename)
        file_path = os.path.join(UPLOADS_DIR, saved_filename)
        document_id = str(uuid.uuid4())
        try:
            with _connect() as conn:
                conn.execute('''
                    INSERT INTO documents (id, filename, original_filename, category, similarity,
                                           ocr_text, upload_date, file_path)
                    VALUES (?, ?, ?, ?, ?, ?, ?, ?)
                ''', (document_id, saved_filename, original_filename, best_match,
                      round(similarity, 3), ocr_text, datetime.now().isoformat(), file_path))
                conn.commit()
        except Exception:
            # Don't leave an orphaned upload on disk when the row wasn't written.
            os.remove(file_path)
            raise

        return jsonify({
            "status": "success",
            "category": best_match,
            "similarity": round(similarity, 3),
            "confidence": "high",
            "matches": matches,
            "document_saved": True,
            "document_id": document_id,
            "ocr_preview": ocr_text[:200] + "..." if len(ocr_text) > 200 else ocr_text
        })
    except Exception as e:
        print(f"Classification error: {e}")
        return jsonify({"error": str(e)}), 500


@app.route("/api/categories", methods=["GET"])
def get_categories():
    """List distinct category labels and how many reference images each has."""
    auth_error = _check_auth()
    if auth_error:
        return auth_error

    category_counts = dict(Counter(labels))
    categories = list(category_counts)  # distinct labels, first-seen order
    return jsonify({"categories": categories, "counts": category_counts})


@app.route("/api/documents/<category>", methods=["GET"])
def get_documents_by_category(category):
    """Return all stored documents of *category*, newest first."""
    auth_error = _check_auth()
    if auth_error:
        return auth_error

    columns = ("id", "filename", "original_filename", "category", "similarity",
               "ocr_text", "upload_date", "file_path")
    with _connect() as conn:
        rows = conn.execute(
            f'SELECT {", ".join(columns)} FROM documents WHERE category = ? ORDER BY upload_date DESC',
            (category,),
        ).fetchall()
    documents = [dict(zip(columns, row)) for row in rows]
    return jsonify({"documents": documents, "count": len(documents)})


@app.route("/api/documents/<document_id>", methods=["DELETE"])
def delete_document(document_id):
    """Delete a stored document's file and its database row."""
    auth_error = _check_auth()
    if auth_error:
        return auth_error

    try:
        with _connect() as conn:
            result = conn.execute('SELECT file_path FROM documents WHERE id = ?',
                                  (document_id,)).fetchone()
            if not result:
                return jsonify({"error": "Document not found"}), 404

            file_path = result[0]
            if file_path and os.path.exists(file_path):
                os.remove(file_path)

            conn.execute('DELETE FROM documents WHERE id = ?', (document_id,))
            conn.commit()

        return jsonify({"message": "Document deleted successfully", "status": "success"})
    except Exception as e:
        return jsonify({"error": str(e)}), 500


@app.route("/api/ocr", methods=["POST"])
def ocr_document():
    """Run OCR on an uploaded image/PDF without classifying or storing it."""
    auth_error = _check_auth()
    if auth_error:
        return auth_error

    try:
        file = request.files.get("file")
        if not file:
            return jsonify({"error": "Missing file"}), 400

        image = _load_upload_image(file.read(), file.content_type)
        if image is None:
            return jsonify({"error": "Failed to process image"}), 400

        text = extract_text(image)
        ocr_method = "Nanonets OCR-s" if ocr_model is not None else "Pytesseract"
        return jsonify({
            "text": text,
            "status": "success",
            "ocr_method": ocr_method,
            "enhanced_features": ocr_model is not None
        })
    except Exception as e:
        return jsonify({"error": str(e)}), 500


@app.route("/api/stats", methods=["GET"])
def get_stats():
    """Return category count, stored document count and per-category document counts."""
    auth_error = _check_auth()
    if auth_error:
        return auth_error

    with _connect() as conn:
        category_stats = dict(conn.execute(
            'SELECT category, COUNT(*) FROM documents GROUP BY category').fetchall())
        total_documents = conn.execute('SELECT COUNT(*) FROM documents').fetchone()[0]

    return jsonify({
        "total_categories": len(set(labels)),
        "total_documents": total_documents,
        "category_distribution": category_stats
    })


@app.route("/api/document-preview/<document_id>", methods=["GET"])
def get_document_preview(document_id):
    """Stream a stored document's file; accepts the token via header or ?token=."""
    auth_error = _check_auth(allow_query_token=True)
    if auth_error:
        return auth_error

    try:
        with _connect() as conn:
            result = conn.execute('SELECT file_path FROM documents WHERE id = ?',
                                  (document_id,)).fetchone()
        if not result:
            return jsonify({"error": "Document not found"}), 404

        # send_from_directory resolves a relative directory against the app's
        # root_path, while uploads were saved relative to the CWD; an absolute
        # path makes both agree.
        file_path = os.path.abspath(result[0])
        if not os.path.exists(file_path):
            return jsonify({"error": "File not found"}), 404

        return send_from_directory(os.path.dirname(file_path), os.path.basename(file_path))
    except Exception as e:
        return jsonify({"error": str(e)}), 500


if __name__ == "__main__":
    # The Werkzeug debugger allows arbitrary code execution and this binds to
    # all interfaces, so it is opt-in via FLASK_DEBUG=1 rather than always on.
    app.run(host="0.0.0.0", port=7860, debug=os.environ.get("FLASK_DEBUG") == "1")