"""Handwritten archive document digitalization service.

Reference images (or the first page of a PDF) are registered under a category
label by storing their CLIP image embedding in a FAISS index. Uploaded
documents are classified by nearest-neighbour search against that index; when
the match is confident enough the file is kept under ``data/uploads``, OCR'd
with Tesseract and recorded in ``data/documents.json``.
"""

import base64
import io
import json
import os
import pickle
import uuid
from collections import Counter
from datetime import datetime, timedelta, timezone
from typing import Any, Dict, List, Optional

import clip
import faiss
import jwt
import numpy as np
import pytesseract
import torch
from fastapi import Depends, FastAPI, File, Form, HTTPException, UploadFile, status
from fastapi.responses import HTMLResponse, JSONResponse
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
from fastapi.staticfiles import StaticFiles
from passlib.context import CryptContext
from pdf2image import convert_from_bytes
from PIL import Image

app = FastAPI(title="Handwritten Archive Document Digitalization System")

# --- Security configuration ---
# The JWT signing key may be injected through the SECRET_KEY environment
# variable; the literal fallback keeps existing behaviour for local use only.
SECRET_KEY = os.environ.get("SECRET_KEY", "your-secret-key-change-this-in-production")
ALGORITHM = "HS256"
ACCESS_TOKEN_EXPIRE_MINUTES = 30

pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
security = HTTPBearer()

# Default admin user (change in production).
USERS_DB = {
    "admin": {
        "username": "admin",
        "hashed_password": pwd_context.hash("admin123"),
        "is_active": True,
    }
}

# --- Storage locations ---
INDEX_PATH = "data/index.faiss"
LABELS_PATH = "data/labels.pkl"
DOCUMENTS_PATH = "data/documents.json"
UPLOADS_DIR = "data/uploads"

# Create directories *before* mounting: StaticFiles verifies at construction
# time that its directory exists and raises RuntimeError otherwise, so mounting
# first crashed the app on a fresh checkout without a ``static`` folder.
os.makedirs("data", exist_ok=True)
os.makedirs("static", exist_ok=True)
os.makedirs(UPLOADS_DIR, exist_ok=True)

app.mount("/static", StaticFiles(directory="static"), name="static")

# --- Model ---
device = "cuda" if torch.cuda.is_available() else "cpu"
clip_model, preprocess = clip.load("ViT-B/32", device=device)

# --- In-memory state shared by all requests ---
# 512 is the dimensionality of a CLIP ViT-B/32 image embedding.
index = faiss.IndexFlatL2(512)
labels = []  # labels[i] is the category of the i-th vector added to ``index``
documents = []  # classified-document records, persisted to DOCUMENTS_PATH

if os.path.exists(INDEX_PATH) and os.path.exists(LABELS_PATH):
    try:
        loaded_index = faiss.read_index(INDEX_PATH)
        # NOTE(security): pickle.load can execute arbitrary code; LABELS_PATH
        # must only ever be written by save_index().
        with open(LABELS_PATH, "rb") as f:
            loaded_labels = pickle.load(f)
        # Commit only after both files loaded, so a failure while unpickling
        # cannot leave loaded vectors in ``index`` alongside an empty ``labels``.
        index, labels = loaded_index, loaded_labels
        print(f"✅ Loaded existing index with {len(labels)} labels")
    except (RuntimeError, EOFError, pickle.UnpicklingError) as e:
        print(f"⚠️ Failed to load existing index: {e}")
        print("🔄 Starting with fresh index")
        for stale_path in (INDEX_PATH, LABELS_PATH):
            if os.path.exists(stale_path):
                os.remove(stale_path)

# Load documents database; an unreadable file means starting empty.
if os.path.exists(DOCUMENTS_PATH):
    try:
        with open(DOCUMENTS_PATH, "r", encoding="utf-8") as f:
            documents = json.load(f)
    except (OSError, ValueError) as e:  # ValueError covers JSONDecodeError/UnicodeDecodeError
        print(f"⚠️ Failed to load documents database: {e}")
        documents = []


# --- Authentication ---
def verify_password(plain_password, hashed_password):
    """Return True when *plain_password* matches *hashed_password* (bcrypt)."""
    return pwd_context.verify(plain_password, hashed_password)


def get_password_hash(password):
    """Hash *password* with the module's bcrypt CryptContext."""
    return pwd_context.hash(password)


def authenticate_user(username: str, password: str):
    """Return the USERS_DB record for valid credentials, otherwise False."""
    user = USERS_DB.get(username)
    if not user or not verify_password(password, user["hashed_password"]):
        return False
    return user


def create_access_token(data: dict, expires_delta: Optional[timedelta] = None):
    """Sign *data* as a JWT carrying an ``exp`` claim.

    The token expires after *expires_delta*, or after 15 minutes when it is
    omitted (or zero). *data* itself is not modified.
    """
    to_encode = data.copy()
    # Timezone-aware "now": datetime.utcnow() is deprecated since Python 3.12.
    expire = datetime.now(timezone.utc) + (expires_delta or timedelta(minutes=15))
    to_encode.update({"exp": expire})
    return jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM)


async def get_current_user(credentials: HTTPAuthorizationCredentials = Depends(security)):
    """FastAPI dependency resolving the bearer token to a USERS_DB record.

    Raises:
        HTTPException: 401 when the token cannot be decoded/verified, has no
            ``sub`` claim, or names an unknown user.
    """
    credentials_exception = HTTPException(
        status_code=status.HTTP_401_UNAUTHORIZED,
        detail="Could not validate credentials",
        headers={"WWW-Authenticate": "Bearer"},
    )
    try:
        payload = jwt.decode(credentials.credentials, SECRET_KEY, algorithms=[ALGORITHM])
    except jwt.PyJWTError:
        raise credentials_exception
    username = payload.get("sub")
    if username is None:
        raise credentials_exception
    user = USERS_DB.get(username)
    if user is None:
        raise credentials_exception
    return user


# --- Utilities ---
def save_index():
    """Persist ``index`` and ``labels`` to disk; failures are logged, not raised."""
    try:
        os.makedirs("data", exist_ok=True)
        faiss.write_index(index, INDEX_PATH)
        with open(LABELS_PATH, "wb") as f:
            pickle.dump(labels, f)
    except Exception as e:
        print(f"❌ Failed to save index: {e}")


def save_documents():
    """Persist ``documents`` as JSON; failures are logged, not raised."""
    try:
        tmp_path = f"{DOCUMENTS_PATH}.tmp"
        with open(tmp_path, "w", encoding="utf-8") as f:
            json.dump(documents, f, indent=2)
        # Swap in only after a complete dump so a serialization error can no
        # longer leave a truncated database behind.
        os.replace(tmp_path, DOCUMENTS_PATH)
    except Exception as e:
        print(f"❌ Failed to save documents: {e}")


def image_from_pdf(pdf_bytes):
    """Render the first page of a PDF at 200 dpi; return None on any failure."""
    try:
        # Only page 1 is used, so don't rasterise the rest of the document.
        images = convert_from_bytes(pdf_bytes, dpi=200, first_page=1, last_page=1)
        return images[0]
    except Exception as e:
        print(f"❌ PDF conversion error: {e}")
        return None


def _to_rgb_image(image):
    """Coerce bytes, an array-like, or a PIL image into an RGB ``PIL.Image``."""
    if isinstance(image, bytes):
        image = Image.open(io.BytesIO(image))
    elif not isinstance(image, Image.Image):
        image = Image.fromarray(image)
    if image.mode != "RGB":
        image = image.convert("RGB")
    return image


def extract_text(image):
    """OCR *image* with Tesseract.

    Never raises: problems are reported through the returned string
    ("❌ ..." on error, "❓ No text detected" when only whitespace is found).
    """
    try:
        if image is None:
            return "❌ No image provided"
        image = _to_rgb_image(image)
        # --oem 3: default engine selection; --psm 6: treat the page as a
        # single uniform block of text.
        custom_config = r"--oem 3 --psm 6"
        text = pytesseract.image_to_string(image, config=custom_config).strip()
        return text if text else "❓ No text detected"
    except Exception as e:
        return f"❌ OCR error: {str(e)}"


def get_clip_embedding(image):
    """Return the L2-normalised CLIP image embedding as a 1-D float32 array.

    Returns None when *image* is None or embedding fails (the error is logged).
    """
    try:
        if image is None:
            return None
        image = _to_rgb_image(image)
        image_input = preprocess(image).unsqueeze(0).to(device)
        with torch.no_grad():
            image_features = clip_model.encode_image(image_input)
            image_features = image_features / image_features.norm(dim=-1, keepdim=True)
        # Cast explicitly: FAISS works on float32, while the model may emit
        # half precision (e.g. when loaded on CUDA).
        return image_features.float().cpu().numpy()[0]
    except Exception as e:
        print(f"❌ CLIP embedding error: {e}")
        return None


def save_uploaded_file(file_content: bytes, filename: str) -> str:
    """Write an upload to UPLOADS_DIR under a random UUID name.

    Only the extension of the client-supplied *filename* is reused (it may be
    None for some uploads). Returns the stored file name, not the full path.
    """
    file_extension = os.path.splitext(filename or "")[1]
    saved_filename = f"{uuid.uuid4()}{file_extension}"
    with open(os.path.join(UPLOADS_DIR, saved_filename), "wb") as f:
        f.write(file_content)
    return saved_filename


def _load_upload_image(file_content: bytes, content_type: Optional[str]):
    """Decode an uploaded file into a PIL image (first page for PDFs).

    Returns None when the bytes cannot be decoded, so endpoints answer 400
    instead of surfacing a decoder exception as a 500.
    """
    if content_type and content_type.startswith("application/pdf"):
        return image_from_pdf(file_content)
    try:
        image = Image.open(io.BytesIO(file_content))
        image.load()  # Image.open is lazy; force decoding so corrupt data fails here
        return image
    except OSError as e:  # PIL.UnidentifiedImageError is an OSError subclass
        print(f"❌ Image decoding error: {e}")
        return None


# --- API Endpoints ---
@app.get("/", response_class=HTMLResponse)
async def dashboard():
    """Serve the dashboard page from static/index.html."""
    with open("static/index.html", "r", encoding="utf-8") as f:
        return HTMLResponse(content=f.read())


@app.post("/api/login")
async def login(username: str = Form(...), password: str = Form(...)):
    """Exchange form credentials for a bearer token.

    Raises:
        HTTPException: 401 on unknown user or wrong password.
    """
    user = authenticate_user(username, password)
    if not user:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Incorrect username or password",
        )
    access_token = create_access_token(
        data={"sub": user["username"]},
        expires_delta=timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES),
    )
    return {"access_token": access_token, "token_type": "bearer", "username": user["username"]}


@app.post("/api/upload-category")
async def upload_category(
    file: UploadFile = File(...),
    label: str = Form(...),
    current_user: dict = Depends(get_current_user),
):
    """Register a reference image (or PDF first page) under *label*.

    The CLIP embedding is appended to ``index`` and *label* to ``labels``
    (repeated labels are allowed), then both are saved to disk.

    Raises:
        HTTPException: 400 for a blank label or undecodable/unembeddable file,
            500 for any other failure.
    """
    try:
        if not label or not label.strip():
            raise HTTPException(status_code=400, detail="Please provide a label")
        label = label.strip()

        file_content = await file.read()
        image = _load_upload_image(file_content, file.content_type)
        if image is None:
            raise HTTPException(status_code=400, detail="Failed to process image")

        embedding = get_clip_embedding(image)
        if embedding is None:
            raise HTTPException(status_code=400, detail="Failed to generate embedding")

        index.add(np.array([embedding]))
        labels.append(label)
        save_index()
        return {
            "message": f"✅ Added category '{label}' (Total: {len(labels)} categories)",
            "status": "success",
        }
    except HTTPException:
        # Keep deliberate 4xx responses instead of re-wrapping them as 500s.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))


@app.post("/api/classify-document")
async def classify_document(
    file: UploadFile = File(...),
    current_user: dict = Depends(get_current_user),
):
    """Classify an uploaded document against the registered categories.

    Returns the best category plus up to three scored matches. When the best
    score reaches the confidence threshold the file is saved, OCR'd and
    appended to ``documents``; otherwise nothing is stored.

    Raises:
        HTTPException: 400 when no categories exist, the file can't be decoded
            or embedded, or no valid neighbour is found; 500 otherwise.
    """
    try:
        if len(labels) == 0:
            raise HTTPException(
                status_code=400, detail="No categories in database. Please add some first."
            )

        file_content = await file.read()
        image = _load_upload_image(file_content, file.content_type)
        if image is None:
            raise HTTPException(status_code=400, detail="Failed to process image")

        embedding = get_clip_embedding(image)
        if embedding is None:
            raise HTTPException(status_code=400, detail="Failed to generate embedding")

        # Search for top 3 matches
        k = min(3, len(labels))
        distances, ids = index.search(np.array([embedding]), k=k)

        # FAISS pads missing neighbours with id -1 (previously read as
        # labels[-1]); also drop ids that don't map to a label. Converting to
        # Python float/int keeps the values JSON-serialisable (np.float32 is not).
        hits = [
            (float(dist), int(idx))
            for dist, idx in zip(distances[0], ids[0])
            if 0 <= idx < len(labels)
        ]
        if not hits:
            raise HTTPException(status_code=400, detail="Document not recognized")

        # IndexFlatL2 returns squared L2 distances; for the unit-normalised
        # CLIP vectors that equals 2 - 2*cos, so this score is 2*cos - 1 (not
        # cosine itself). The 0.35 threshold therefore corresponds to cos >= 0.675.
        best_distance, best_idx = hits[0]
        similarity = 1 - best_distance
        confidence_threshold = 0.35
        best_match = labels[best_idx]
        matches = [
            {"category": labels[idx], "similarity": round(1 - dist, 3)}
            for dist, idx in hits
        ]

        if similarity < confidence_threshold:
            return {
                "status": "low_confidence",
                "category": best_match,
                "similarity": round(similarity, 3),
                "confidence": "low",
                "matches": matches,
                "document_saved": False,
            }

        # Save classified document
        saved_filename = save_uploaded_file(file_content, file.filename)
        document = {
            "id": str(uuid.uuid4()),
            "filename": saved_filename,
            "original_filename": file.filename,
            "category": best_match,
            "similarity": round(similarity, 3),
            "ocr_text": extract_text(image),
            "upload_date": datetime.now().isoformat(),
            "file_path": os.path.join(UPLOADS_DIR, saved_filename),
        }
        documents.append(document)
        save_documents()
        return {
            "status": "success",
            "category": best_match,
            "similarity": round(similarity, 3),
            "confidence": "high",
            "matches": matches,
            "document_saved": True,
            "document_id": document["id"],
        }
    except HTTPException:
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))


@app.get("/api/categories")
async def get_categories(current_user: dict = Depends(get_current_user)):
    """List distinct category labels and how many reference images each has."""
    # Counter keeps first-seen order, so the category list is stable between
    # calls (list(set(...)) had arbitrary order).
    category_counts = Counter(labels)
    return {"categories": list(category_counts), "counts": dict(category_counts)}


@app.get("/api/documents/{category}")
async def get_documents_by_category(
    category: str,
    current_user: dict = Depends(get_current_user),
):
    """Return the stored documents classified under *category*."""
    category_documents = [doc for doc in documents if doc.get("category") == category]
    return {"documents": category_documents, "count": len(category_documents)}


@app.get("/api/documents")
async def get_all_documents(current_user: dict = Depends(get_current_user)):
    """Return every stored document record."""
    return {"documents": documents, "count": len(documents)}


@app.delete("/api/documents/{document_id}")
async def delete_document(
    document_id: str,
    current_user: dict = Depends(get_current_user),
):
    """Delete a document record and its stored file.

    Raises:
        HTTPException: 404 if no record has *document_id*, 500 on other errors.
    """
    try:
        document_index = next(
            (i for i, doc in enumerate(documents) if doc.get("id") == document_id),
            None,
        )
        if document_index is None:
            raise HTTPException(status_code=404, detail="Document not found")

        # Delete physical file
        file_path = documents[document_index].get("file_path")
        if file_path and os.path.exists(file_path):
            os.remove(file_path)

        documents.pop(document_index)
        save_documents()
        return {"message": "Document deleted successfully", "status": "success"}
    except HTTPException:
        # Previously the 404 above was swallowed and returned as a 500.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))


@app.post("/api/ocr")
async def ocr_document(
    file: UploadFile = File(...),
    current_user: dict = Depends(get_current_user),
):
    """Run OCR on an uploaded image (or PDF first page) without storing it.

    Raises:
        HTTPException: 400 if the file can't be decoded, 500 on other errors.
    """
    try:
        file_content = await file.read()
        image = _load_upload_image(file_content, file.content_type)
        if image is None:
            raise HTTPException(status_code=400, detail="Failed to process image")
        return {"text": extract_text(image), "status": "success"}
    except HTTPException:
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
@app.get("/api/stats")
async def get_stats(current_user: dict = Depends(get_current_user)):
    """Summarise the archive for the dashboard.

    Reports the number of distinct category labels, the number of stored
    documents, and how many stored documents fall under each category
    (in first-seen order). ``current_user`` is unused; the dependency is
    there to require authentication.
    """
    distribution = {}
    for record in documents:
        name = record["category"]
        distribution[name] = distribution.get(name, 0) + 1

    summary = {
        "total_categories": len(set(labels)),
        "total_documents": len(documents),
        "category_distribution": distribution,
    }
    return summary


if __name__ == "__main__":
    import uvicorn

    # Development entry point: serve on all interfaces, port 8000.
    uvicorn.run(app, host="0.0.0.0", port=8000)