from flask import Flask, request, jsonify, send_from_directory
from werkzeug.utils import secure_filename
from werkzeug.security import generate_password_hash, check_password_hash
import pytesseract
from PIL import Image
import numpy as np
import faiss
import os
import pickle
from pdf2image import convert_from_bytes
import torch
import clip
import io
import json
import uuid
from datetime import datetime, timedelta
import jwt
import sqlite3
import tempfile
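# Runtime dependencies beyond pip packages: pytesseract needs the tesseract-ocr
# binary on PATH, and pdf2image needs poppler (pdftoppm). `clip` here is the
# OpenAI CLIP package (pip install git+https://github.com/openai/CLIP.git).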
app = Flask(__name__)
app.config['SECRET_KEY'] = 'your-secret-key-change-this-in-production'

# Security configuration
SECRET_KEY = "your-secret-key-change-this-in-production"
ALGORITHM = "HS256"
ACCESS_TOKEN_EXPIRE_MINUTES = 30

# Set CLIP cache to writable directory
os.environ['CLIP_CACHE'] = '/app/clip_cache'
os.makedirs('/app/clip_cache', exist_ok=True)
# Directories
INDEX_PATH = "data/index.faiss"
LABELS_PATH = "data/labels.pkl"
DATABASE_PATH = "data/documents.db"
UPLOADS_DIR = "data/uploads"
os.makedirs("data", exist_ok=True)
os.makedirs("static", exist_ok=True)
os.makedirs(UPLOADS_DIR, exist_ok=True)
# Initialize database
def init_db():
    conn = sqlite3.connect(DATABASE_PATH)
    cursor = conn.cursor()

    # Users table
    cursor.execute('''
        CREATE TABLE IF NOT EXISTS users (
            id INTEGER PRIMARY KEY AUTOINCREMENT,
            username TEXT UNIQUE NOT NULL,
            password_hash TEXT NOT NULL,
            is_active BOOLEAN DEFAULT TRUE
        )
    ''')

    # Documents table
    cursor.execute('''
        CREATE TABLE IF NOT EXISTS documents (
            id TEXT PRIMARY KEY,
            filename TEXT NOT NULL,
            original_filename TEXT NOT NULL,
            category TEXT NOT NULL,
            similarity REAL NOT NULL,
            ocr_text TEXT,
            upload_date TEXT NOT NULL,
            file_path TEXT NOT NULL
        )
    ''')

    # Insert default admin user if not exists
    cursor.execute('SELECT * FROM users WHERE username = ?', ('admin',))
    if not cursor.fetchone():
        admin_hash = generate_password_hash('admin123')
        cursor.execute('INSERT INTO users (username, password_hash) VALUES (?, ?)',
                       ('admin', admin_hash))
    conn.commit()
    conn.close()

init_db()
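# NOTE: init_db() creates a default admin/admin123 account on first run;
# change or remove it before exposing the app.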
# Initialize index and labels
index = faiss.IndexFlatL2(512)
labels = []
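# 512 is the embedding dimension of CLIP ViT-B/32; IndexFlatL2 stores raw vectors
# and does exact (exhaustive) L2 search, which is fine for a small category set.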
if os.path.exists(INDEX_PATH) and os.path.exists(LABELS_PATH):
    try:
        index = faiss.read_index(INDEX_PATH)
        with open(LABELS_PATH, "rb") as f:
            labels = pickle.load(f)
        print(f"✅ Loaded existing index with {len(labels)} labels")
    except Exception as e:
        print(f"⚠️ Failed to load existing index: {e}")
        if os.path.exists(INDEX_PATH):
            os.remove(INDEX_PATH)
        if os.path.exists(LABELS_PATH):
            os.remove(LABELS_PATH)
# Initialize CLIP model with custom cache
device = "cuda" if torch.cuda.is_available() else "cpu"
try:
    clip_model, preprocess = clip.load("ViT-B/32", device=device, download_root='/app/clip_cache')
    print("✅ CLIP model loaded successfully")
except Exception as e:
    print(f"❌ Failed to load CLIP model: {e}")
    # Fallback initialization
    clip_model = None
    preprocess = None
# Helper functions
def save_index():
    try:
        faiss.write_index(index, INDEX_PATH)
        with open(LABELS_PATH, "wb") as f:
            pickle.dump(labels, f)
    except Exception as e:
        print(f"❌ Failed to save index: {e}")
def authenticate_user(username: str, password: str):
    conn = sqlite3.connect(DATABASE_PATH)
    cursor = conn.cursor()
    cursor.execute('SELECT password_hash FROM users WHERE username = ? AND is_active = TRUE', (username,))
    result = cursor.fetchone()
    conn.close()
    if result and check_password_hash(result[0], password):
        return {"username": username}
    return None
def create_access_token(data: dict):
    expire = datetime.utcnow() + timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES)
    to_encode = data.copy()
    to_encode.update({"exp": expire})
    return jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM)

def verify_token(token: str):
    try:
        payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
        username = payload.get("sub")
        return username if username else None
    except jwt.PyJWTError:
        return None
def image_from_pdf(pdf_bytes):
    try:
        images = convert_from_bytes(pdf_bytes, dpi=200)
        return images[0]
    except Exception as e:
        print(f"❌ PDF conversion error: {e}")
        return None
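# Only the first PDF page is rasterized, so multi-page documents are
# classified and OCR'd from page 1 alone.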
def extract_text(image):
    try:
        if image.mode != 'RGB':
            image = image.convert('RGB')
        custom_config = r'--oem 3 --psm 6'
        text = pytesseract.image_to_string(image, config=custom_config)
        return text.strip() if text.strip() else "⚠️ No text detected"
    except Exception as e:
        return f"❌ OCR error: {str(e)}"
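# Tesseract flags: --oem 3 selects the default (LSTM-based) OCR engine and
# --psm 6 assumes a single uniform block of text, a reasonable default for
# scanned documents.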
def get_clip_embedding(image):
    try:
        if clip_model is None:
            return None
        if image.mode != 'RGB':
            image = image.convert('RGB')
        image_input = preprocess(image).unsqueeze(0).to(device)
        with torch.no_grad():
            image_features = clip_model.encode_image(image_input)
            image_features = image_features / image_features.norm(dim=-1, keepdim=True)
        # Cast to float32: FAISS only accepts float32 arrays, and CLIP returns
        # float16 features when running on CUDA.
        return image_features.cpu().numpy()[0].astype(np.float32)
    except Exception as e:
        print(f"❌ CLIP embedding error: {e}")
        return None
def save_uploaded_file(file_content: bytes, filename: str) -> str:
    file_id = str(uuid.uuid4())
    file_extension = os.path.splitext(filename)[1]
    saved_filename = f"{file_id}{file_extension}"
    file_path = os.path.join(UPLOADS_DIR, saved_filename)
    with open(file_path, 'wb') as f:
        f.write(file_content)
    return saved_filename
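# Uploads are stored under a random UUID filename (collision-safe, no user input
# in the path); the original filename is kept only in the database row.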
# Routes
# NOTE: the @app.route paths and methods below are assumed reconstructions;
# align them with whatever the frontend in static/ actually calls.
@app.route("/")
def dashboard():
    return send_from_directory('static', 'index.html')

@app.route("/static/<path:filename>")
def static_files(filename):
    return send_from_directory('static', filename)

@app.route("/login", methods=["POST"])
def login():
    username = request.form.get("username")
    password = request.form.get("password")
    user = authenticate_user(username, password)
    if not user:
        return jsonify({"detail": "Incorrect username or password"}), 401
    access_token = create_access_token(data={"sub": user["username"]})
    return jsonify({"access_token": access_token, "token_type": "bearer", "username": user["username"]})
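# Example client flow (assuming the /login and /categories routes above), e.g. from a shell:
#   curl -X POST -F "username=admin" -F "password=admin123" http://localhost:7860/login
#   curl -H "Authorization: Bearer <access_token>" http://localhost:7860/categories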
@app.route("/upload_category", methods=["POST"])
def upload_category():
    # Verify token
    auth_header = request.headers.get('Authorization')
    if not auth_header or not auth_header.startswith('Bearer '):
        return jsonify({"error": "Missing or invalid token"}), 401
    token = auth_header.split(' ')[1]
    username = verify_token(token)
    if not username:
        return jsonify({"error": "Invalid token"}), 401

    try:
        label = request.form.get("label")
        file = request.files.get("file")
        if not label or not file:
            return jsonify({"error": "Missing label or file"}), 400

        file_content = file.read()
        if file.content_type and file.content_type.startswith('application/pdf'):
            image = image_from_pdf(file_content)
        else:
            image = Image.open(io.BytesIO(file_content))
        if image is None:
            return jsonify({"error": "Failed to process image"}), 400

        embedding = get_clip_embedding(image)
        if embedding is None:
            return jsonify({"error": "Failed to generate embedding"}), 400

        index.add(np.array([embedding]))
        labels.append(label.strip())
        save_index()
        return jsonify({"message": f"✅ Added category '{label}' (Total: {len(labels)} categories)", "status": "success"})
    except Exception as e:
        return jsonify({"error": str(e)}), 500
@app.route("/classify", methods=["POST"])
def classify_document():
    # Verify token
    auth_header = request.headers.get('Authorization')
    if not auth_header or not auth_header.startswith('Bearer '):
        return jsonify({"error": "Missing or invalid token"}), 401
    token = auth_header.split(' ')[1]
    username = verify_token(token)
    if not username:
        return jsonify({"error": "Invalid token"}), 401

    try:
        if len(labels) == 0:
            return jsonify({"error": "No categories in database. Please add some first."}), 400

        file = request.files.get("file")
        if not file:
            return jsonify({"error": "Missing file"}), 400

        file_content = file.read()
        if file.content_type and file.content_type.startswith('application/pdf'):
            image = image_from_pdf(file_content)
        else:
            image = Image.open(io.BytesIO(file_content))
        if image is None:
            return jsonify({"error": "Failed to process image"}), 400

        embedding = get_clip_embedding(image)
        if embedding is None:
            return jsonify({"error": "Failed to generate embedding"}), 400

        k = min(3, len(labels))
        D, I = index.search(np.array([embedding]), k=k)
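        # D holds squared L2 distances; for unit-normalized embeddings the exact
        # cosine similarity would be 1 - D/2. The score below uses 1 - D as a
        # simpler heuristic, and the 0.35 threshold applies to that scale.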
        if len(labels) > 0 and I[0][0] < len(labels):
            # Cast to Python float so jsonify and sqlite3 can handle the value
            # (FAISS returns numpy float32 scalars).
            similarity = float(1 - D[0][0])
            confidence_threshold = 0.35
            best_match = labels[I[0][0]]

            matches = []
            for i in range(min(k, len(D[0]))):
                if I[0][i] < len(labels):
                    sim = float(1 - D[0][i])
                    matches.append({"category": labels[I[0][i]], "similarity": round(sim, 3)})

            # Save classified document to SQLite
            if similarity >= confidence_threshold:
                saved_filename = save_uploaded_file(file_content, file.filename)
                ocr_text = extract_text(image)
                document_id = str(uuid.uuid4())

                conn = sqlite3.connect(DATABASE_PATH)
                cursor = conn.cursor()
                cursor.execute('''
                    INSERT INTO documents (id, filename, original_filename, category, similarity, ocr_text, upload_date, file_path)
                    VALUES (?, ?, ?, ?, ?, ?, ?, ?)
                ''', (document_id, saved_filename, file.filename, best_match, round(similarity, 3),
                      ocr_text, datetime.now().isoformat(), os.path.join(UPLOADS_DIR, saved_filename)))
                conn.commit()
                conn.close()

                return jsonify({
                    "status": "success",
                    "category": best_match,
                    "similarity": round(similarity, 3),
                    "confidence": "high",
                    "matches": matches,
                    "document_saved": True,
                    "document_id": document_id
                })
            else:
                return jsonify({
                    "status": "low_confidence",
                    "category": best_match,
                    "similarity": round(similarity, 3),
                    "confidence": "low",
                    "matches": matches,
                    "document_saved": False
                })

        return jsonify({"error": "Document not recognized"}), 400
    except Exception as e:
        return jsonify({"error": str(e)}), 500
@app.route("/categories")
def get_categories():
    # Verify token
    auth_header = request.headers.get('Authorization')
    if not auth_header or not auth_header.startswith('Bearer '):
        return jsonify({"error": "Missing or invalid token"}), 401
    token = auth_header.split(' ')[1]
    username = verify_token(token)
    if not username:
        return jsonify({"error": "Invalid token"}), 401

    categories = list(set(labels))  # Remove duplicates
    category_counts = {}
    for label in labels:
        category_counts[label] = category_counts.get(label, 0) + 1
    return jsonify({"categories": categories, "counts": category_counts})
@app.route("/documents/<category>")
def get_documents_by_category(category):
    # Verify token
    auth_header = request.headers.get('Authorization')
    if not auth_header or not auth_header.startswith('Bearer '):
        return jsonify({"error": "Missing or invalid token"}), 401
    token = auth_header.split(' ')[1]
    username = verify_token(token)
    if not username:
        return jsonify({"error": "Invalid token"}), 401

    conn = sqlite3.connect(DATABASE_PATH)
    cursor = conn.cursor()
    cursor.execute('SELECT * FROM documents WHERE category = ? ORDER BY upload_date DESC', (category,))
    documents = []
    for row in cursor.fetchall():
        documents.append({
            "id": row[0],
            "filename": row[1],
            "original_filename": row[2],
            "category": row[3],
            "similarity": row[4],
            "ocr_text": row[5],
            "upload_date": row[6],
            "file_path": row[7]
        })
    conn.close()
    return jsonify({"documents": documents, "count": len(documents)})
@app.route("/documents/<document_id>", methods=["DELETE"])
def delete_document(document_id):
    # Verify token
    auth_header = request.headers.get('Authorization')
    if not auth_header or not auth_header.startswith('Bearer '):
        return jsonify({"error": "Missing or invalid token"}), 401
    token = auth_header.split(' ')[1]
    username = verify_token(token)
    if not username:
        return jsonify({"error": "Invalid token"}), 401

    try:
        conn = sqlite3.connect(DATABASE_PATH)
        cursor = conn.cursor()
        # Get document info first
        cursor.execute('SELECT file_path FROM documents WHERE id = ?', (document_id,))
        result = cursor.fetchone()
        if not result:
            conn.close()
            return jsonify({"error": "Document not found"}), 404

        file_path = result[0]
        # Delete physical file
        if file_path and os.path.exists(file_path):
            os.remove(file_path)

        # Delete from database
        cursor.execute('DELETE FROM documents WHERE id = ?', (document_id,))
        conn.commit()
        conn.close()
        return jsonify({"message": "Document deleted successfully", "status": "success"})
    except Exception as e:
        return jsonify({"error": str(e)}), 500
@app.route("/ocr", methods=["POST"])
def ocr_document():
    # Verify token
    auth_header = request.headers.get('Authorization')
    if not auth_header or not auth_header.startswith('Bearer '):
        return jsonify({"error": "Missing or invalid token"}), 401
    token = auth_header.split(' ')[1]
    username = verify_token(token)
    if not username:
        return jsonify({"error": "Invalid token"}), 401

    try:
        file = request.files.get("file")
        if not file:
            return jsonify({"error": "Missing file"}), 400

        file_content = file.read()
        if file.content_type and file.content_type.startswith('application/pdf'):
            image = image_from_pdf(file_content)
        else:
            image = Image.open(io.BytesIO(file_content))
        if image is None:
            return jsonify({"error": "Failed to process image"}), 400

        text = extract_text(image)
        return jsonify({"text": text, "status": "success"})
    except Exception as e:
        return jsonify({"error": str(e)}), 500
@app.route("/stats")
def get_stats():
    # Verify token
    auth_header = request.headers.get('Authorization')
    if not auth_header or not auth_header.startswith('Bearer '):
        return jsonify({"error": "Missing or invalid token"}), 401
    token = auth_header.split(' ')[1]
    username = verify_token(token)
    if not username:
        return jsonify({"error": "Invalid token"}), 401

    conn = sqlite3.connect(DATABASE_PATH)
    cursor = conn.cursor()
    cursor.execute('SELECT category, COUNT(*) FROM documents GROUP BY category')
    category_stats = dict(cursor.fetchall())
    cursor.execute('SELECT COUNT(*) FROM documents')
    total_documents = cursor.fetchone()[0]
    conn.close()

    return jsonify({
        "total_categories": len(set(labels)),
        "total_documents": total_documents,
        "category_distribution": category_stats
    })
@app.route("/preview/<document_id>")
def get_document_preview(document_id):
    # Verify token
    auth_header = request.headers.get('Authorization')
    if not auth_header or not auth_header.startswith('Bearer '):
        # For image requests, try to get token from query params as fallback
        token = request.args.get('token')
        if not token:
            return jsonify({"error": "Missing or invalid token"}), 401
        username = verify_token(token)
    else:
        token = auth_header.split(' ')[1]
        username = verify_token(token)
    if not username:
        return jsonify({"error": "Invalid token"}), 401

    try:
        conn = sqlite3.connect(DATABASE_PATH)
        cursor = conn.cursor()
        cursor.execute('SELECT file_path FROM documents WHERE id = ?', (document_id,))
        result = cursor.fetchone()
        conn.close()
        if not result:
            return jsonify({"error": "Document not found"}), 404

        file_path = result[0]
        if not os.path.exists(file_path):
            return jsonify({"error": "File not found"}), 404
        return send_from_directory(os.path.dirname(file_path), os.path.basename(file_path))
    except Exception as e:
        return jsonify({"error": str(e)}), 500
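# Entry point. Port 7860 is the default port expected by Hugging Face Spaces;
# debug=True should be disabled for anything beyond local development.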
if __name__ == "__main__":
    app.run(host="0.0.0.0", port=7860, debug=True)