# handwritten / app.py
# IZERE HIRWA Roger — po — 5753ed4
# (hosting page header: raw | history | blame | 16.7 kB)
from flask import Flask, request, jsonify, send_from_directory
from werkzeug.utils import secure_filename
from werkzeug.security import generate_password_hash, check_password_hash
import pytesseract
from PIL import Image
import numpy as np
import faiss
import os
import pickle
from pdf2image import convert_from_bytes
import torch
import clip
import io
import json
import uuid
from datetime import datetime, timedelta
import jwt
import sqlite3
import tempfile
app = Flask(__name__)

# Security configuration.
# The JWT/Flask signing key can be supplied through the SECRET_KEY environment
# variable; the fallback below is a placeholder and must be replaced in
# production. It is defined once so Flask's config and the JWT helpers can
# never drift apart.
SECRET_KEY = os.environ.get("SECRET_KEY", "your-secret-key-change-this-in-production")
app.config['SECRET_KEY'] = SECRET_KEY
ALGORITHM = "HS256"
ACCESS_TOKEN_EXPIRE_MINUTES = 30

# Set CLIP cache to a writable directory.
os.environ['CLIP_CACHE'] = '/app/clip_cache'
os.makedirs('/app/clip_cache', exist_ok=True)

# Persistent storage locations (relative to the working directory).
INDEX_PATH = "data/index.faiss"
LABELS_PATH = "data/labels.pkl"
DATABASE_PATH = "data/documents.db"
UPLOADS_DIR = "data/uploads"
os.makedirs("data", exist_ok=True)
os.makedirs("static", exist_ok=True)
os.makedirs(UPLOADS_DIR, exist_ok=True)
# Initialize database
def init_db():
    """Create the SQLite schema and seed a default admin account.

    Creates the ``users`` and ``documents`` tables if they do not exist. When
    no ``admin`` user exists yet, one is inserted. Its password is read from
    the ``ADMIN_PASSWORD`` environment variable and falls back to
    ``admin123``; change it before deploying.
    """
    from contextlib import closing

    # sqlite3's own context manager only commits/rolls back; closing() also
    # releases the connection when a statement raises.
    with closing(sqlite3.connect(DATABASE_PATH)) as conn:
        cursor = conn.cursor()
        # Users table
        cursor.execute('''
            CREATE TABLE IF NOT EXISTS users (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                username TEXT UNIQUE NOT NULL,
                password_hash TEXT NOT NULL,
                is_active BOOLEAN DEFAULT TRUE
            )
        ''')
        # Documents table
        cursor.execute('''
            CREATE TABLE IF NOT EXISTS documents (
                id TEXT PRIMARY KEY,
                filename TEXT NOT NULL,
                original_filename TEXT NOT NULL,
                category TEXT NOT NULL,
                similarity REAL NOT NULL,
                ocr_text TEXT,
                upload_date TEXT NOT NULL,
                file_path TEXT NOT NULL
            )
        ''')
        # Insert the default admin user if it does not exist yet.
        cursor.execute('SELECT 1 FROM users WHERE username = ?', ('admin',))
        if cursor.fetchone() is None:
            admin_password = os.environ.get('ADMIN_PASSWORD', 'admin123')
            admin_hash = generate_password_hash(admin_password)
            cursor.execute('INSERT INTO users (username, password_hash) VALUES (?, ?)',
                           ('admin', admin_hash))
        conn.commit()
init_db()

# Initialize index and labels.
# ``index`` stores one 512-d CLIP embedding per reference image and
# ``labels[i]`` is the category name of vector ``i``; the two must stay aligned.
index = faiss.IndexFlatL2(512)
labels = []
if os.path.exists(INDEX_PATH) and os.path.exists(LABELS_PATH):
    try:
        loaded_index = faiss.read_index(INDEX_PATH)
        # NOTE: pickle can execute arbitrary code while loading. This file is
        # only written by save_index(); never let untrusted users replace it.
        with open(LABELS_PATH, "rb") as f:
            loaded_labels = list(pickle.load(f))
        if loaded_index.d != 512:
            raise ValueError(f"index dimension is {loaded_index.d}, expected 512")
        if loaded_index.ntotal != len(loaded_labels):
            raise ValueError(
                f"index has {loaded_index.ntotal} vectors but {len(loaded_labels)} labels"
            )
        # Publish both only after they loaded and agree, so a partial failure
        # cannot leave a populated index next to an empty label list.
        index, labels = loaded_index, loaded_labels
        print(f"βœ… Loaded existing index with {len(labels)} labels")
    except Exception as e:
        print(f"⚠️ Failed to load existing index: {e}")
        if os.path.exists(INDEX_PATH):
            os.remove(INDEX_PATH)
        if os.path.exists(LABELS_PATH):
            os.remove(LABELS_PATH)
# Initialize CLIP model with custom cache.
# Use the GPU when PyTorch reports one, otherwise run on the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
# Start in the "model unavailable" state; get_clip_embedding() checks for None.
clip_model = None
preprocess = None
try:
    clip_model, preprocess = clip.load(
        "ViT-B/32", device=device, download_root='/app/clip_cache'
    )
    print("βœ… CLIP model loaded successfully")
except Exception as e:
    print(f"❌ Failed to load CLIP model: {e}")
    # Fallback: keep both unset so requests report an embedding failure.
    clip_model, preprocess = None, None
# Helper functions
def save_index():
    """Persist the FAISS index and its label list to disk.

    Each file is first written to a ``.tmp`` sibling and then moved into place
    with ``os.replace`` (an atomic rename on POSIX). A crash mid-write therefore
    cannot leave a truncated file that fails to load at the next start.
    Failures are printed and swallowed, so this stays a best-effort save.
    """
    index_tmp = f"{INDEX_PATH}.tmp"
    labels_tmp = f"{LABELS_PATH}.tmp"
    try:
        faiss.write_index(index, index_tmp)
        with open(labels_tmp, "wb") as f:
            pickle.dump(labels, f)
        os.replace(index_tmp, INDEX_PATH)
        os.replace(labels_tmp, LABELS_PATH)
    except Exception as e:
        print(f"❌ Failed to save index: {e}")
        # Don't leave half-written temporaries behind.
        for tmp_path in (index_tmp, labels_tmp):
            try:
                os.remove(tmp_path)
            except OSError:
                pass
def authenticate_user(username: str, password: str):
    """Check credentials against the ``users`` table.

    Args:
        username: Login name. May be ``None`` when the form field is absent.
        password: Plain-text password. May be ``None`` when the form field is absent.

    Returns:
        ``{"username": username}`` when an active user with that name exists
        and the password matches its stored hash; otherwise ``None``.
    """
    from contextlib import closing

    # Missing form fields arrive as None, and check_password_hash cannot hash
    # None. Reject early so a bad login is a clean failure, not a 500.
    if not username or not password:
        return None
    with closing(sqlite3.connect(DATABASE_PATH)) as conn:
        row = conn.execute(
            'SELECT password_hash FROM users WHERE username = ? AND is_active = TRUE',
            (username,),
        ).fetchone()
    if row and check_password_hash(row[0], password):
        return {"username": username}
    return None
def create_access_token(data: dict):
    """Return a signed JWT containing *data* plus an ``exp`` claim.

    The token expires ``ACCESS_TOKEN_EXPIRE_MINUTES`` minutes from now and is
    signed with ``SECRET_KEY`` using ``ALGORITHM``. *data* is not modified.
    """
    from datetime import timezone

    # Use a timezone-aware UTC timestamp: datetime.utcnow() is deprecated since
    # Python 3.12. PyJWT converts both forms to the same epoch seconds.
    expire = datetime.now(timezone.utc) + timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES)
    to_encode = {**data, "exp": expire}
    return jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM)
def verify_token(token: str):
    """Decode *token* and return its ``sub`` (username) claim.

    Returns ``None`` when PyJWT rejects the token (bad signature, expired,
    malformed) or when the ``sub`` claim is missing or empty.
    """
    try:
        claims = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
    except jwt.PyJWTError:
        return None
    return claims.get("sub") or None
def image_from_pdf(pdf_bytes):
    """Render the first page of a PDF to a PIL image at 200 DPI.

    Args:
        pdf_bytes: Raw PDF file content.

    Returns:
        The first page as a PIL image, or ``None`` when rendering fails or the
        document yields no pages. The reason is printed.
    """
    try:
        # Only the first page is used, so render just that page instead of
        # rasterising the whole document.
        images = convert_from_bytes(pdf_bytes, dpi=200, first_page=1, last_page=1)
    except Exception as e:
        print(f"❌ PDF conversion error: {e}")
        return None
    if not images:
        print("❌ PDF conversion error: document has no pages")
        return None
    return images[0]
def extract_text(image):
    """Run Tesseract OCR on *image* and return the recognised text.

    The image is converted to RGB first if needed. The result is the stripped
    text, ``"❓ No text detected"`` when nothing was recognised, or an
    ``"❌ OCR error: ..."`` string instead of an exception when OCR fails.
    """
    try:
        rgb = image if image.mode == 'RGB' else image.convert('RGB')
        # Tesseract flags: engine mode (--oem) 3 and page segmentation mode (--psm) 6.
        raw = pytesseract.image_to_string(rgb, config=r'--oem 3 --psm 6')
        cleaned = raw.strip()
        return cleaned or "❓ No text detected"
    except Exception as e:
        return f"❌ OCR error: {str(e)}"
def get_clip_embedding(image):
    """Return the unit-length CLIP image embedding of *image*.

    The image is converted to RGB if needed, preprocessed and encoded with the
    global ``clip_model``. The features are cast to float32 before
    normalisation: the model may emit half precision depending on the device,
    and FAISS works on float32.

    Returns:
        A 1-D float32 NumPy vector, or ``None`` when the model is not loaded or
        encoding raises. Errors are printed.
    """
    try:
        if clip_model is None:
            return None
        rgb = image if image.mode == 'RGB' else image.convert('RGB')
        batch = preprocess(rgb).unsqueeze(0).to(device)
        with torch.no_grad():
            features = clip_model.encode_image(batch).float()
        # Unit-normalise so distances between embeddings reflect angular similarity.
        features = features / features.norm(dim=-1, keepdim=True)
        return features.cpu().numpy()[0]
    except Exception as e:
        print(f"❌ CLIP embedding error: {e}")
        return None
def save_uploaded_file(file_content: bytes, filename: str) -> str:
    """Write an uploaded file into ``UPLOADS_DIR`` under a random name.

    The stored name is a fresh UUID4 plus the original file's extension.
    The extension comes from client-supplied input, so it is kept only when
    it is a short alphanumeric suffix (e.g. ``.pdf`` or ``.PNG``). Otherwise,
    or when *filename* is empty or ``None``, no extension is used.

    Returns:
        The stored file name (not the full path).
    """
    import re

    extension = os.path.splitext(filename or "")[1]
    if not re.fullmatch(r"\.[A-Za-z0-9]{1,16}", extension):
        extension = ""
    saved_filename = f"{uuid.uuid4()}{extension}"
    with open(os.path.join(UPLOADS_DIR, saved_filename), 'wb') as f:
        f.write(file_content)
    return saved_filename
# Routes
@app.route("/")
def dashboard():
    """Serve the front-end page ``static/index.html``."""
    page = 'index.html'
    return send_from_directory('static', page)
@app.route("/static/<path:filename>")
def static_files(filename):
    """Serve *filename* from the ``static`` directory.

    NOTE(review): Flask also registers its own ``/static/<path:filename>``
    rule for the default static folder, which likely matches first and
    makes this view redundant. Confirm before relying on it.
    """
    static_dir = 'static'
    return send_from_directory(static_dir, filename)
@app.route("/api/login", methods=["POST"])
def login():
    """Exchange form-encoded ``username``/``password`` for a bearer JWT.

    Responds with 401 and ``{"detail": ...}`` when the credentials are
    rejected. On success it returns the access token, its type and the username.
    """
    form = request.form
    user = authenticate_user(form.get("username"), form.get("password"))
    if not user:
        return jsonify({"detail": "Incorrect username or password"}), 401
    name = user["username"]
    token = create_access_token(data={"sub": name})
    return jsonify({"access_token": token, "token_type": "bearer", "username": name})
@app.route("/api/upload-category", methods=["POST"])
def upload_category():
    """Register a reference image (or a PDF's first page) for a category.

    Requires ``Authorization: Bearer <token>``. Expects the multipart fields
    ``label`` (category name) and ``file``. The file's CLIP embedding is
    appended to the FAISS index, the stripped label is appended to ``labels``
    at the same position, and both are persisted.

    Status codes:
        401: missing or invalid token.
        400: missing or blank label, missing file, or an unreadable image/PDF.
        500: unexpected errors.
    """
    # Verify token
    auth_header = request.headers.get('Authorization')
    if not auth_header or not auth_header.startswith('Bearer '):
        return jsonify({"error": "Missing or invalid token"}), 401
    token = auth_header.split(' ')[1]
    username = verify_token(token)
    if not username:
        return jsonify({"error": "Invalid token"}), 401
    try:
        # Strip before validating so a whitespace-only label is rejected
        # instead of being stored as an empty category name.
        label = (request.form.get("label") or "").strip()
        file = request.files.get("file")
        if not label or not file:
            return jsonify({"error": "Missing label or file"}), 400
        file_content = file.read()
        # Recognise PDFs by MIME type or by the "%PDF" magic bytes, in case the
        # client sends a generic content type.
        is_pdf = ((file.content_type or "").startswith('application/pdf')
                  or file_content.startswith(b'%PDF'))
        if is_pdf:
            image = image_from_pdf(file_content)
        else:
            try:
                image = Image.open(io.BytesIO(file_content))
            except OSError:  # PIL.UnidentifiedImageError subclasses OSError
                image = None
        if image is None:
            return jsonify({"error": "Failed to process image"}), 400
        embedding = get_clip_embedding(image)
        if embedding is None:
            return jsonify({"error": "Failed to generate embedding"}), 400
        # FAISS takes an (n, d) float32 matrix.
        index.add(np.asarray([embedding], dtype=np.float32))
        labels.append(label)
        save_index()
        return jsonify({"message": f"βœ… Added category '{label}' (Total: {len(labels)} categories)", "status": "success"})
    except Exception as e:
        return jsonify({"error": str(e)}), 500
@app.route("/api/classify-document", methods=["POST"])
def classify_document():
    """Classify an uploaded image/PDF against the registered categories.

    Requires ``Authorization: Bearer <token>`` and a multipart ``file``. The
    file's CLIP embedding is searched against the FAISS index (up to 3
    neighbours).
    - If the best match's score is at least 0.35, the file is stored, OCR'd and
      recorded in the ``documents`` table (``"confidence": "high"``).
    - Otherwise the matches are returned without saving (``"low_confidence"``).

    Status codes:
        401: missing or invalid token.
        400: no categories, missing or unreadable file, embedding failure, or
             no valid neighbour.
        500: unexpected errors.
    """
    from contextlib import closing

    # Verify token
    auth_header = request.headers.get('Authorization')
    if not auth_header or not auth_header.startswith('Bearer '):
        return jsonify({"error": "Missing or invalid token"}), 401
    token = auth_header.split(' ')[1]
    username = verify_token(token)
    if not username:
        return jsonify({"error": "Invalid token"}), 401
    try:
        if len(labels) == 0:
            return jsonify({"error": "No categories in database. Please add some first."}), 400
        file = request.files.get("file")
        if not file:
            return jsonify({"error": "Missing file"}), 400
        file_content = file.read()
        # Recognise PDFs by MIME type or by the "%PDF" magic bytes.
        is_pdf = ((file.content_type or "").startswith('application/pdf')
                  or file_content.startswith(b'%PDF'))
        if is_pdf:
            image = image_from_pdf(file_content)
        else:
            try:
                image = Image.open(io.BytesIO(file_content))
            except OSError:  # PIL.UnidentifiedImageError subclasses OSError
                image = None
        if image is None:
            return jsonify({"error": "Failed to process image"}), 400
        embedding = get_clip_embedding(image)
        if embedding is None:
            return jsonify({"error": "Failed to generate embedding"}), 400

        k = min(3, len(labels))
        D, I = index.search(np.asarray([embedding], dtype=np.float32), k=k)

        # FAISS pads missing results with id -1. A plain "< len(labels)" check
        # would let -1 through and silently pick labels[-1].
        best_idx = int(I[0][0])
        if not 0 <= best_idx < len(labels):
            return jsonify({"error": "Document not recognized"}), 400

        # NOTE(review): IndexFlatL2 reports squared L2 distances. For unit
        # vectors that is 2 - 2*cos, so "1 - D" is 2*cos - 1, not cosine
        # similarity. The 0.35 threshold is expressed on this scale, so both
        # are kept unchanged.
        # float() converts the NumPy scalar to a Python float. A NumPy float32
        # scalar (which NumPy 2 keeps as float32 here) is serialisable by
        # neither jsonify nor sqlite3.
        similarity = float(1 - D[0][0])
        confidence_threshold = 0.35
        best_match = labels[best_idx]
        matches = [
            {"category": labels[int(idx)], "similarity": round(float(1 - dist), 3)}
            for dist, idx in zip(D[0], I[0])
            if 0 <= int(idx) < len(labels)
        ]

        if similarity < confidence_threshold:
            return jsonify({
                "status": "low_confidence",
                "category": best_match,
                "similarity": round(similarity, 3),
                "confidence": "low",
                "matches": matches,
                "document_saved": False
            })

        # Save classified document to SQLite.
        saved_filename = save_uploaded_file(file_content, file.filename)
        ocr_text = extract_text(image)
        document_id = str(uuid.uuid4())
        with closing(sqlite3.connect(DATABASE_PATH)) as conn:
            conn.execute('''
                INSERT INTO documents (id, filename, original_filename, category, similarity, ocr_text, upload_date, file_path)
                VALUES (?, ?, ?, ?, ?, ?, ?, ?)
            ''', (document_id, saved_filename, file.filename, best_match, round(similarity, 3),
                  ocr_text, datetime.now().isoformat(), os.path.join(UPLOADS_DIR, saved_filename)))
            conn.commit()
        return jsonify({
            "status": "success",
            "category": best_match,
            "similarity": round(similarity, 3),
            "confidence": "high",
            "matches": matches,
            "document_saved": True,
            "document_id": document_id
        })
    except Exception as e:
        return jsonify({"error": str(e)}), 500
@app.route("/api/categories", methods=["GET"])
def get_categories():
    """List the distinct categories and the number of reference images each has.

    Requires ``Authorization: Bearer <token>``. ``categories`` lists the
    distinct labels in the order they were first registered. ``counts`` maps
    each label to its number of index entries.
    """
    # Verify token
    auth_header = request.headers.get('Authorization')
    if not auth_header or not auth_header.startswith('Bearer '):
        return jsonify({"error": "Missing or invalid token"}), 401
    token = auth_header.split(' ')[1]
    username = verify_token(token)
    if not username:
        return jsonify({"error": "Invalid token"}), 401
    category_counts = {}
    for label in labels:
        category_counts[label] = category_counts.get(label, 0) + 1
    # Dicts preserve insertion order, so this gives a deterministic
    # de-duplicated list, unlike list(set(labels)).
    categories = list(category_counts)
    return jsonify({"categories": categories, "counts": category_counts})
@app.route("/api/documents/<category>", methods=["GET"])
def get_documents_by_category(category):
    """Return all stored documents of *category*, newest first.

    Requires ``Authorization: Bearer <token>``. Each document is returned as
    a dict keyed by the ``documents`` table's column names.
    """
    from contextlib import closing

    # Verify token
    auth_header = request.headers.get('Authorization')
    if not auth_header or not auth_header.startswith('Bearer '):
        return jsonify({"error": "Missing or invalid token"}), 401
    token = auth_header.split(' ')[1]
    username = verify_token(token)
    if not username:
        return jsonify({"error": "Invalid token"}), 401
    # Select columns explicitly so the row -> dict mapping doesn't depend on
    # the table's physical column order (as it would with SELECT *).
    columns = ("id", "filename", "original_filename", "category",
               "similarity", "ocr_text", "upload_date", "file_path")
    with closing(sqlite3.connect(DATABASE_PATH)) as conn:
        rows = conn.execute(
            'SELECT id, filename, original_filename, category, similarity, '
            'ocr_text, upload_date, file_path '
            'FROM documents WHERE category = ? ORDER BY upload_date DESC',
            (category,),
        ).fetchall()
    documents = [dict(zip(columns, row)) for row in rows]
    return jsonify({"documents": documents, "count": len(documents)})
@app.route("/api/documents/<document_id>", methods=["DELETE"])
def delete_document(document_id):
    """Delete a stored document's database row and its uploaded file.

    Requires ``Authorization: Bearer <token>``. Returns 404 when the id is
    unknown and 500 on database errors. File removal is best-effort: if the
    file is already missing or cannot be removed, the deletion still succeeds
    and a warning is printed.
    """
    from contextlib import closing

    # Verify token
    auth_header = request.headers.get('Authorization')
    if not auth_header or not auth_header.startswith('Bearer '):
        return jsonify({"error": "Missing or invalid token"}), 401
    token = auth_header.split(' ')[1]
    username = verify_token(token)
    if not username:
        return jsonify({"error": "Invalid token"}), 401
    try:
        with closing(sqlite3.connect(DATABASE_PATH)) as conn:
            # Get document info first
            row = conn.execute('SELECT file_path FROM documents WHERE id = ?',
                               (document_id,)).fetchone()
            if row is None:
                return jsonify({"error": "Document not found"}), 404
            file_path = row[0]
            # Remove the row before the file: if the DB step fails, the file
            # is still on disk and the record stays consistent.
            conn.execute('DELETE FROM documents WHERE id = ?', (document_id,))
            conn.commit()
        if file_path:
            try:
                os.remove(file_path)
            except FileNotFoundError:
                pass
            except OSError as e:
                print(f"⚠️ Failed to delete file {file_path}: {e}")
        return jsonify({"message": "Document deleted successfully", "status": "success"})
    except Exception as e:
        return jsonify({"error": str(e)}), 500
@app.route("/api/ocr", methods=["POST"])
def ocr_document():
    """Run OCR on an uploaded image (or a PDF's first page) and return the text.

    Requires ``Authorization: Bearer <token>`` and a multipart ``file``.

    Status codes:
        401: missing or invalid token.
        400: missing or unreadable file.
        500: unexpected errors.

    OCR failures are reported inside ``text`` by extract_text().
    """
    # Verify token
    auth_header = request.headers.get('Authorization')
    if not auth_header or not auth_header.startswith('Bearer '):
        return jsonify({"error": "Missing or invalid token"}), 401
    token = auth_header.split(' ')[1]
    username = verify_token(token)
    if not username:
        return jsonify({"error": "Invalid token"}), 401
    try:
        file = request.files.get("file")
        if not file:
            return jsonify({"error": "Missing file"}), 400
        file_content = file.read()
        # Recognise PDFs by MIME type or by the "%PDF" magic bytes.
        is_pdf = ((file.content_type or "").startswith('application/pdf')
                  or file_content.startswith(b'%PDF'))
        if is_pdf:
            image = image_from_pdf(file_content)
        else:
            try:
                image = Image.open(io.BytesIO(file_content))
            except OSError:  # PIL.UnidentifiedImageError subclasses OSError
                image = None
        if image is None:
            return jsonify({"error": "Failed to process image"}), 400
        text = extract_text(image)
        return jsonify({"text": text, "status": "success"})
    except Exception as e:
        return jsonify({"error": str(e)}), 500
@app.route("/api/stats", methods=["GET"])
def get_stats():
    """Return summary counts for the dashboard.

    Requires ``Authorization: Bearer <token>``. Reports the number of distinct
    categories in the index, the number of stored documents, and the stored
    documents per category.
    """
    from contextlib import closing

    # Verify token
    auth_header = request.headers.get('Authorization')
    if not auth_header or not auth_header.startswith('Bearer '):
        return jsonify({"error": "Missing or invalid token"}), 401
    token = auth_header.split(' ')[1]
    username = verify_token(token)
    if not username:
        return jsonify({"error": "Invalid token"}), 401
    # closing() releases the connection even if a query raises.
    with closing(sqlite3.connect(DATABASE_PATH)) as conn:
        category_stats = dict(conn.execute(
            'SELECT category, COUNT(*) FROM documents GROUP BY category'
        ).fetchall())
        total_documents = conn.execute('SELECT COUNT(*) FROM documents').fetchone()[0]
    return jsonify({
        "total_categories": len(set(labels)),
        "total_documents": total_documents,
        "category_distribution": category_stats
    })
if __name__ == "__main__":
    # The Werkzeug debugger can execute arbitrary Python from the browser, so
    # it must not be on by default for a server bound to all interfaces.
    # Opt in for local development with FLASK_DEBUG=1 (or "true"/"yes").
    debug_enabled = os.environ.get("FLASK_DEBUG", "").strip().lower() in ("1", "true", "yes")
    app.run(host="0.0.0.0", port=7860, debug=debug_enabled)