Spaces:
Sleeping
Sleeping
import io
import json
import os
import pickle
import sqlite3
import tempfile
import uuid
from collections import Counter
from contextlib import closing
from datetime import datetime, timedelta, timezone

# Third-party imports are kept in their original relative order.
from flask import Flask, request, jsonify, send_from_directory
from werkzeug.utils import secure_filename
from werkzeug.security import generate_password_hash, check_password_hash
import pytesseract
from PIL import Image
import numpy as np
import faiss
from pdf2image import convert_from_bytes
import torch
import clip
import jwt
app = Flask(__name__)
# Security configuration
# A single signing key is shared by Flask (app.config) and the JWT helpers
# (create_access_token / verify_token). It is read from the SECRET_KEY
# environment variable. The fallback is a public placeholder: anyone who knows
# it can forge valid access tokens, so always set SECRET_KEY in production.
SECRET_KEY = os.environ.get("SECRET_KEY", "your-secret-key-change-this-in-production")
app.config['SECRET_KEY'] = SECRET_KEY
ALGORITHM = "HS256"
ACCESS_TOKEN_EXPIRE_MINUTES = 30
# Set CLIP cache to writable directory
os.environ['CLIP_CACHE'] = '/app/clip_cache'
os.makedirs('/app/clip_cache', exist_ok=True)
# Directories (all persistent state lives under ./data)
INDEX_PATH = "data/index.faiss"
LABELS_PATH = "data/labels.pkl"
DATABASE_PATH = "data/documents.db"
UPLOADS_DIR = "data/uploads"
os.makedirs("data", exist_ok=True)
os.makedirs("static", exist_ok=True)
os.makedirs(UPLOADS_DIR, exist_ok=True)
# Initialize database
def init_db():
    """Create the SQLite schema and seed a default ``admin`` account.

    Creates the ``users`` and ``documents`` tables when they do not exist yet.
    Inserts an ``admin`` user only when no row with that username is present,
    so repeated calls leave an existing database unchanged.

    The seeded admin password is read from the ``ADMIN_PASSWORD`` environment
    variable and falls back to ``admin123``.
    """
    # SECURITY: 'admin123' is a well-known default. Set ADMIN_PASSWORD before
    # the first start; it is only used when the admin row is first created.
    admin_password = os.environ.get("ADMIN_PASSWORD", "admin123")
    # closing() releases the connection even if a statement raises.
    with closing(sqlite3.connect(DATABASE_PATH)) as conn:
        cursor = conn.cursor()
        # Users table
        cursor.execute('''
            CREATE TABLE IF NOT EXISTS users (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                username TEXT UNIQUE NOT NULL,
                password_hash TEXT NOT NULL,
                is_active BOOLEAN DEFAULT TRUE
            )
        ''')
        # Documents table
        cursor.execute('''
            CREATE TABLE IF NOT EXISTS documents (
                id TEXT PRIMARY KEY,
                filename TEXT NOT NULL,
                original_filename TEXT NOT NULL,
                category TEXT NOT NULL,
                similarity REAL NOT NULL,
                ocr_text TEXT,
                upload_date TEXT NOT NULL,
                file_path TEXT NOT NULL
            )
        ''')
        # Insert default admin user if not exists
        cursor.execute('SELECT * FROM users WHERE username = ?', ('admin',))
        if not cursor.fetchone():
            admin_hash = generate_password_hash(admin_password)
            cursor.execute('INSERT INTO users (username, password_hash) VALUES (?, ?)',
                           ('admin', admin_hash))
        conn.commit()
init_db()
# Initialize index and labels.
# The index and the labels list are parallel: the vector with id i in the
# index is the reference embedding for labels[i].
index = faiss.IndexFlatL2(512)
labels = []
if os.path.exists(INDEX_PATH) and os.path.exists(LABELS_PATH):
    try:
        # Load into temporaries and publish both together. A failure while
        # reading the labels then cannot leave a populated index paired with
        # an empty labels list.
        loaded_index = faiss.read_index(INDEX_PATH)
        # NOTE(review): pickle.load can execute arbitrary code. This is only
        # acceptable while LABELS_PATH is written exclusively by this app.
        with open(LABELS_PATH, "rb") as f:
            loaded_labels = pickle.load(f)
        index, labels = loaded_index, loaded_labels
        print(f"β Loaded existing index with {len(labels)} labels")
        if index.ntotal != len(labels):
            print(f"Warning: index holds {index.ntotal} vectors but there are {len(labels)} labels")
    except Exception as e:
        print(f"β οΈ Failed to load existing index: {e}")
        # Corrupt state is discarded; the app starts with an empty index.
        if os.path.exists(INDEX_PATH):
            os.remove(INDEX_PATH)
        if os.path.exists(LABELS_PATH):
            os.remove(LABELS_PATH)
# Initialize CLIP model with custom cache
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"
try:
    clip_model, preprocess = clip.load("ViT-B/32", device=device, download_root='/app/clip_cache')
    print("β CLIP model loaded successfully")
except Exception as e:
    print(f"β Failed to load CLIP model: {e}")
    # Degrade instead of crashing at import time: both stay None, and
    # get_clip_embedding() returns None whenever clip_model is None.
    clip_model = preprocess = None
# Helper functions
def _write_atomically(path, write):
    """Call ``write(tmp_path)`` on a fresh temp file next to *path*, then move it over *path*.

    The temp file lives in the same directory, so os.replace() is a rename on
    one filesystem (atomic on POSIX). Readers therefore see either the old
    file or the complete new one. The temp file is removed if *write* or the
    move fails, and the error is re-raised.
    """
    fd, tmp_path = tempfile.mkstemp(dir=os.path.dirname(path) or ".", suffix=".tmp")
    os.close(fd)
    try:
        write(tmp_path)
        os.replace(tmp_path, path)
    except BaseException:
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
        raise


def save_index():
    """Persist the global FAISS ``index`` and ``labels`` list to disk.

    Each file is written to a temporary file and swapped into place, so an
    interrupted save cannot leave a truncated file behind. At startup, files
    that fail to load are deleted, which would lose every category. Errors
    are logged and not raised.
    """
    def dump_labels(tmp_path):
        with open(tmp_path, "wb") as f:
            pickle.dump(labels, f)

    try:
        _write_atomically(INDEX_PATH, lambda tmp_path: faiss.write_index(index, tmp_path))
        _write_atomically(LABELS_PATH, dump_labels)
    except Exception as e:
        print(f"β Failed to save index: {e}")
def authenticate_user(username: str, password: str):
    """Check *username* / *password* against the ``users`` table.

    Returns ``{"username": username}`` for an active user whose stored hash
    matches *password*, otherwise None.

    Missing or empty credentials are rejected up front. Without this, a
    request naming an existing user but omitting the password would pass
    None to check_password_hash, which expects a string, and fail with an
    exception instead of a clean None.
    """
    if not username or not password:
        return None
    with closing(sqlite3.connect(DATABASE_PATH)) as conn:
        result = conn.execute(
            'SELECT password_hash FROM users WHERE username = ? AND is_active = TRUE',
            (username,),
        ).fetchone()
    if result and check_password_hash(result[0], password):
        return {"username": username}
    return None
def create_access_token(data: dict):
    """Return a signed JWT containing *data* plus an ``exp`` claim.

    The token expires ACCESS_TOKEN_EXPIRE_MINUTES from now and is signed with
    SECRET_KEY using ALGORITHM. *data* itself is not modified.
    """
    # Use a timezone-aware UTC timestamp: datetime.utcnow() is deprecated
    # since Python 3.12. PyJWT encodes both forms to the same epoch value.
    expire = datetime.now(timezone.utc) + timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES)
    to_encode = {**data, "exp": expire}
    return jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM)
def verify_token(token: str):
    """Decode *token* and return its ``sub`` claim, or None when unusable.

    None is returned when PyJWT rejects the token (bad signature, expired,
    malformed, ...) or when the ``sub`` claim is missing or empty.
    """
    try:
        claims = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
        subject = claims.get("sub")
    except jwt.PyJWTError:
        return None
    return subject or None
def image_from_pdf(pdf_bytes):
    """Render the first page of a PDF, given as raw bytes, to a PIL image.

    Returns None, after logging the error, if conversion fails or no page is
    produced (the resulting IndexError is caught like any other error).
    """
    try:
        # Only page 1 is used, so render just that page instead of
        # rasterising every page of the document at 200 dpi.
        images = convert_from_bytes(pdf_bytes, dpi=200, first_page=1, last_page=1)
        return images[0]
    except Exception as e:
        print(f"β PDF conversion error: {e}")
        return None
def extract_text(image):
    """OCR *image* with Tesseract and return the recognised text.

    Returns a "No text detected" placeholder when OCR yields only whitespace.
    Any Exception raised during OCR is converted into an "OCR error: ..."
    string rather than propagated.
    """
    try:
        rgb_image = image if image.mode == 'RGB' else image.convert('RGB')
        # --oem / --psm select Tesseract's engine and page-segmentation mode;
        # see the Tesseract CLI documentation before tuning them.
        recognised = pytesseract.image_to_string(rgb_image, config=r'--oem 3 --psm 6').strip()
        return recognised or "β No text detected"
    except Exception as e:
        return f"β OCR error: {str(e)}"
def get_clip_embedding(image):
    """Return the L2-normalised CLIP image embedding of *image*, or None.

    Returns None when the CLIP model failed to load at startup, or (after
    logging) when preprocessing/encoding raises. The result is a 1-D float32
    numpy array: the single row of the batch of one built below.
    """
    try:
        if clip_model is None:
            return None
        if image.mode != 'RGB':
            image = image.convert('RGB')
        image_input = preprocess(image).unsqueeze(0).to(device)
        with torch.no_grad():
            # Cast to float32: the model may run in half precision on CUDA.
            # Normalising in float32 is more accurate, and the FAISS index
            # stores float32 vectors.
            image_features = clip_model.encode_image(image_input).float()
        image_features = image_features / image_features.norm(dim=-1, keepdim=True)
        return image_features.cpu().numpy()[0]
    except Exception as e:
        print(f"β CLIP embedding error: {e}")
        return None
def save_uploaded_file(file_content: bytes, filename: str) -> str:
    """Write *file_content* into UPLOADS_DIR under a fresh UUID-based name.

    Only the extension of the client-supplied *filename* is kept. The name is
    sanitised with werkzeug's secure_filename first, and a missing (None)
    filename yields no extension.

    Returns the stored file name (relative to UPLOADS_DIR), not the full path.
    """
    file_id = str(uuid.uuid4())
    # The filename comes from the client, so sanitise it before trusting
    # any part of it.
    file_extension = os.path.splitext(secure_filename(filename or ""))[1]
    saved_filename = f"{file_id}{file_extension}"
    file_path = os.path.join(UPLOADS_DIR, saved_filename)
    with open(file_path, 'wb') as f:
        f.write(file_content)
    return saved_filename
# Routes
def dashboard():
    """Serve the front-end entry page, static/index.html."""
    page = 'index.html'
    return send_from_directory('static', page)
def static_files(filename):
    """Serve *filename* from the ``static`` directory."""
    static_root = 'static'
    return send_from_directory(static_root, filename)
def login():
    """Exchange form credentials for a bearer access token.

    Reads ``username`` and ``password`` from the submitted form. On failure
    responds 401 with a ``detail`` message. On success returns the JWT, the
    token type ("bearer") and the username.
    """
    form = request.form
    user = authenticate_user(form.get("username"), form.get("password"))
    if not user:
        return jsonify({"detail": "Incorrect username or password"}), 401
    name = user["username"]
    token = create_access_token(data={"sub": name})
    return jsonify({"access_token": token, "token_type": "bearer", "username": name})
def upload_category():
    """Register a reference document for a category.

    Requires an ``Authorization: Bearer <token>`` header. Expects the form
    field ``label`` and a multipart ``file`` (PDF or image). The file's CLIP
    embedding is appended to the FAISS index and the label to ``labels``,
    then both are saved. Responds 400 for bad input, 401 for auth failures
    and 500 for unexpected errors.
    """
    # Verify token
    auth_header = request.headers.get('Authorization')
    if not auth_header or not auth_header.startswith('Bearer '):
        return jsonify({"error": "Missing or invalid token"}), 401
    token = auth_header.split(' ')[1]
    username = verify_token(token)
    if not username:
        return jsonify({"error": "Invalid token"}), 401
    try:
        # Strip before validating, so a whitespace-only label is rejected
        # instead of being stored as an empty category name.
        label = (request.form.get("label") or "").strip()
        file = request.files.get("file")
        if not label or not file:
            return jsonify({"error": "Missing label or file"}), 400
        file_content = file.read()
        if file.content_type and file.content_type.startswith('application/pdf'):
            image = image_from_pdf(file_content)
        else:
            image = Image.open(io.BytesIO(file_content))
        if image is None:
            return jsonify({"error": "Failed to process image"}), 400
        embedding = get_clip_embedding(image)
        if embedding is None:
            return jsonify({"error": "Failed to generate embedding"}), 400
        # FAISS stores float32 vectors; cast explicitly in case the embedding
        # arrives in another dtype (e.g. float16).
        index.add(np.array([embedding], dtype=np.float32))
        labels.append(label)
        save_index()
        return jsonify({"message": f"β Added category '{label}' (Total: {len(labels)} categories)", "status": "success"})
    except Exception as e:
        return jsonify({"error": str(e)}), 500
def classify_document():
    """Classify an uploaded document against the registered categories.

    Requires an ``Authorization: Bearer <token>`` header and a multipart
    ``file`` (PDF or image). The file's CLIP embedding is searched in the
    FAISS index and up to three nearest categories are returned. When the
    best score reaches the confidence threshold, the upload is written to
    UPLOADS_DIR, OCR'd, and recorded in the ``documents`` table.

    Responses:
        200 with ``status`` "success" or "low_confidence".
        400 for missing/unusable input or when no valid match is found.
        401 for auth failures.
        500 for unexpected errors.
    """
    # Verify token
    auth_header = request.headers.get('Authorization')
    if not auth_header or not auth_header.startswith('Bearer '):
        return jsonify({"error": "Missing or invalid token"}), 401
    token = auth_header.split(' ')[1]
    username = verify_token(token)
    if not username:
        return jsonify({"error": "Invalid token"}), 401
    try:
        if len(labels) == 0:
            return jsonify({"error": "No categories in database. Please add some first."}), 400
        file = request.files.get("file")
        if not file:
            return jsonify({"error": "Missing file"}), 400
        file_content = file.read()
        if file.content_type and file.content_type.startswith('application/pdf'):
            image = image_from_pdf(file_content)
        else:
            image = Image.open(io.BytesIO(file_content))
        if image is None:
            return jsonify({"error": "Failed to process image"}), 400
        embedding = get_clip_embedding(image)
        if embedding is None:
            return jsonify({"error": "Failed to generate embedding"}), 400

        k = min(3, len(labels))
        # FAISS takes a float32 query matrix with one row per query.
        D, I = index.search(np.array([embedding], dtype=np.float32), k=k)

        # FAISS pads missing results with id -1, and an id past the end of
        # `labels` means index and labels are out of sync. Neither is usable.
        best_id = int(I[0][0])
        if not 0 <= best_id < len(labels):
            return jsonify({"error": "Document not recognized"}), 400

        # Score scale: IndexFlatL2 reports *squared* L2 distances. For the
        # unit-length embeddings produced by get_clip_embedding that is
        # 2 - 2*cos, so "1 - d" equals 2*cos - 1 rather than the cosine.
        # The 0.35 threshold and the similarities already stored in the
        # database use this scale, so it is kept as-is.
        # Scores are converted to Python floats: NumPy float32 scalars are
        # rejected by jsonify and by sqlite3 parameter binding.
        similarity = 1.0 - float(D[0][0])
        confidence_threshold = 0.35
        best_match = labels[best_id]
        matches = [
            {"category": labels[idx], "similarity": round(1.0 - float(dist), 3)}
            for dist, idx in zip(D[0], I[0])
            if 0 <= idx < len(labels)
        ]

        if similarity >= confidence_threshold:
            # Save classified document to SQLite
            saved_filename = save_uploaded_file(file_content, file.filename)
            ocr_text = extract_text(image)
            document_id = str(uuid.uuid4())
            # closing() releases the connection even if the INSERT fails.
            with closing(sqlite3.connect(DATABASE_PATH)) as conn:
                conn.execute('''
                    INSERT INTO documents (id, filename, original_filename, category, similarity, ocr_text, upload_date, file_path)
                    VALUES (?, ?, ?, ?, ?, ?, ?, ?)
                ''', (document_id, saved_filename, file.filename, best_match, round(similarity, 3),
                      ocr_text, datetime.now().isoformat(), os.path.join(UPLOADS_DIR, saved_filename)))
                conn.commit()
            return jsonify({
                "status": "success",
                "category": best_match,
                "similarity": round(similarity, 3),
                "confidence": "high",
                "matches": matches,
                "document_saved": True,
                "document_id": document_id
            })
        return jsonify({
            "status": "low_confidence",
            "category": best_match,
            "similarity": round(similarity, 3),
            "confidence": "low",
            "matches": matches,
            "document_saved": False
        })
    except Exception as e:
        return jsonify({"error": str(e)}), 500
def get_categories():
    """Return the distinct category labels and their reference-sample counts.

    Requires an ``Authorization: Bearer <token>`` header. ``categories`` lists
    each label once, in the order it was first added. ``counts`` maps each
    label to the number of reference embeddings stored under it.
    """
    # Verify token
    auth_header = request.headers.get('Authorization')
    if not auth_header or not auth_header.startswith('Bearer '):
        return jsonify({"error": "Missing or invalid token"}), 401
    token = auth_header.split(' ')[1]
    username = verify_token(token)
    if not username:
        return jsonify({"error": "Invalid token"}), 401
    # Counter keeps first-insertion order, so the category list is stable
    # across restarts. set() order depends on per-process string hashing.
    category_counts = Counter(labels)
    return jsonify({"categories": list(category_counts), "counts": dict(category_counts)})
def get_documents_by_category(category):
    """List stored documents classified under *category*.

    Requires an ``Authorization: Bearer <token>`` header. Rows are ordered by
    ``upload_date`` descending. Returns ``documents`` (one dict per row of
    the ``documents`` table) and their ``count``.
    """
    # Verify token
    auth_header = request.headers.get('Authorization')
    if not auth_header or not auth_header.startswith('Bearer '):
        return jsonify({"error": "Missing or invalid token"}), 401
    token = auth_header.split(' ')[1]
    username = verify_token(token)
    if not username:
        return jsonify({"error": "Invalid token"}), 401
    # Select the columns by name, so the response keys do not depend on the
    # physical column order of the table.
    columns = ("id", "filename", "original_filename", "category",
               "similarity", "ocr_text", "upload_date", "file_path")
    with closing(sqlite3.connect(DATABASE_PATH)) as conn:
        rows = conn.execute(
            'SELECT id, filename, original_filename, category, similarity, ocr_text, upload_date, file_path '
            'FROM documents WHERE category = ? ORDER BY upload_date DESC',
            (category,),
        ).fetchall()
    documents = [dict(zip(columns, row)) for row in rows]
    return jsonify({"documents": documents, "count": len(documents)})
def delete_document(document_id):
    """Delete a stored document: its ``documents`` row and its uploaded file.

    Requires an ``Authorization: Bearer <token>`` header. Returns 404 when no
    row has *document_id*, and 500 with the error text on failure.
    """
    # Verify token
    auth_header = request.headers.get('Authorization')
    if not auth_header or not auth_header.startswith('Bearer '):
        return jsonify({"error": "Missing or invalid token"}), 401
    token = auth_header.split(' ')[1]
    username = verify_token(token)
    if not username:
        return jsonify({"error": "Invalid token"}), 401
    try:
        # closing() releases the connection on every path, including the
        # early 404 return and any exception.
        with closing(sqlite3.connect(DATABASE_PATH)) as conn:
            row = conn.execute('SELECT file_path FROM documents WHERE id = ?', (document_id,)).fetchone()
            if not row:
                return jsonify({"error": "Document not found"}), 404
            file_path = row[0]
            # Remove the row first. If deleting the file fails afterwards, the
            # leftover is an orphaned file, not a listed document whose file
            # is gone.
            conn.execute('DELETE FROM documents WHERE id = ?', (document_id,))
            conn.commit()
        # Delete physical file
        if file_path and os.path.exists(file_path):
            os.remove(file_path)
        return jsonify({"message": "Document deleted successfully", "status": "success"})
    except Exception as e:
        return jsonify({"error": str(e)}), 500
def ocr_document():
    """Run OCR on an uploaded document and return the recognised text.

    Requires an ``Authorization: Bearer <token>`` header and a multipart
    ``file``. PDFs are OCR'd on the page returned by image_from_pdf. Nothing
    is persisted.
    """
    auth_header = request.headers.get('Authorization')
    if not (auth_header and auth_header.startswith('Bearer ')):
        return jsonify({"error": "Missing or invalid token"}), 401
    if not verify_token(auth_header.split(' ')[1]):
        return jsonify({"error": "Invalid token"}), 401
    try:
        upload = request.files.get("file")
        if not upload:
            return jsonify({"error": "Missing file"}), 400
        payload = upload.read()
        is_pdf = bool(upload.content_type) and upload.content_type.startswith('application/pdf')
        image = image_from_pdf(payload) if is_pdf else Image.open(io.BytesIO(payload))
        if image is None:
            return jsonify({"error": "Failed to process image"}), 400
        return jsonify({"text": extract_text(image), "status": "success"})
    except Exception as e:
        return jsonify({"error": str(e)}), 500
def get_stats():
    """Return summary statistics for the dashboard.

    Requires an ``Authorization: Bearer <token>`` header.
    ``total_categories`` counts distinct labels in the in-memory ``labels``
    list. ``total_documents`` and ``category_distribution`` (category -> row
    count) come from the ``documents`` table.
    """
    # Verify token
    auth_header = request.headers.get('Authorization')
    if not auth_header or not auth_header.startswith('Bearer '):
        return jsonify({"error": "Missing or invalid token"}), 401
    token = auth_header.split(' ')[1]
    username = verify_token(token)
    if not username:
        return jsonify({"error": "Invalid token"}), 401
    # closing() releases the connection even if a query raises.
    with closing(sqlite3.connect(DATABASE_PATH)) as conn:
        category_stats = dict(
            conn.execute('SELECT category, COUNT(*) FROM documents GROUP BY category').fetchall()
        )
        total_documents = conn.execute('SELECT COUNT(*) FROM documents').fetchone()[0]
    return jsonify({
        "total_categories": len(set(labels)),
        "total_documents": total_documents,
        "category_distribution": category_stats
    })
if __name__ == "__main__":
    # SECURITY: debug mode enables Werkzeug's interactive debugger, which can
    # execute arbitrary code, and host 0.0.0.0 makes it reachable from the
    # network. Debug is therefore opt-in: set FLASK_DEBUG=1 for development.
    app.run(host="0.0.0.0", port=7860, debug=os.environ.get("FLASK_DEBUG") == "1")