Spaces:
Sleeping
Sleeping
| from fastapi import FastAPI, File, UploadFile, Form, HTTPException, Depends, status | |
| from fastapi.responses import HTMLResponse, JSONResponse | |
| from fastapi.staticfiles import StaticFiles | |
| from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials | |
| import pytesseract | |
| from PIL import Image | |
| import numpy as np | |
| import faiss | |
| import os | |
| import pickle | |
| from pdf2image import convert_from_bytes | |
| import torch | |
| import clip | |
| import io | |
| import json | |
| import uuid | |
| from datetime import datetime, timedelta | |
| from typing import List, Dict, Any, Optional | |
| import base64 | |
| import jwt | |
| from passlib.context import CryptContext | |
app = FastAPI(title="Handwritten Archive Document Digitalization System")

# Security configuration.
# SECRET_KEY signs every issued JWT. It can be supplied through the environment;
# the literal below is only a local-development fallback.
SECRET_KEY = os.environ.get("SECRET_KEY", "your-secret-key-change-this-in-production")
ALGORITHM = "HS256"
ACCESS_TOKEN_EXPIRE_MINUTES = 30
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
security = HTTPBearer()

# Default admin user (change in production). The bootstrap password may be
# overridden with the ADMIN_PASSWORD environment variable; only its bcrypt
# hash is kept in memory.
USERS_DB = {
    "admin": {
        "username": "admin",
        "hashed_password": pwd_context.hash(os.environ.get("ADMIN_PASSWORD", "admin123")),
        "is_active": True,
    }
}

# --- Load or Initialize Model/Index ---
device = "cuda" if torch.cuda.is_available() else "cpu"
clip_model, preprocess = clip.load("ViT-B/32", device=device)

INDEX_PATH = "data/index.faiss"
LABELS_PATH = "data/labels.pkl"
DOCUMENTS_PATH = "data/documents.json"
UPLOADS_DIR = "data/uploads"

# Ensure directories exist. This must run BEFORE the static mount below:
# StaticFiles checks at construction time that its directory exists and raises
# otherwise, which would abort start-up whenever "static" is missing.
os.makedirs("data", exist_ok=True)
os.makedirs("static", exist_ok=True)
os.makedirs(UPLOADS_DIR, exist_ok=True)

# Mount static files
app.mount("/static", StaticFiles(directory="static"), name="static")

# Initialize index and labels with error handling.
# Row i of the FAISS index corresponds to labels[i]; 512 is the width of the
# CLIP ViT-B/32 image embedding.
index = faiss.IndexFlatL2(512)
labels = []
documents = []
if os.path.exists(INDEX_PATH) and os.path.exists(LABELS_PATH):
    try:
        index = faiss.read_index(INDEX_PATH)
        # NOTE: unpickling can execute arbitrary code; acceptable only because
        # this file is produced by this application's own save_index().
        with open(LABELS_PATH, "rb") as f:
            labels = pickle.load(f)
        print(f"β Loaded existing index with {len(labels)} labels")
    except (RuntimeError, EOFError, OSError, ValueError, pickle.UnpicklingError) as e:
        print(f"β οΈ Failed to load existing index: {e}")
        print("π Starting with fresh index")
        # Reset BOTH halves: read_index may already have succeeded before the
        # labels failed to load, and a populated index paired with an empty
        # label list would misalign row ids and labels for every later upload.
        index = faiss.IndexFlatL2(512)
        labels = []
        if os.path.exists(INDEX_PATH):
            os.remove(INDEX_PATH)
        if os.path.exists(LABELS_PATH):
            os.remove(LABELS_PATH)

# Load documents database (a JSON list of document records).
if os.path.exists(DOCUMENTS_PATH):
    try:
        with open(DOCUMENTS_PATH, "r", encoding="utf-8") as f:
            documents = json.load(f)
    except (OSError, ValueError) as e:
        # json.JSONDecodeError is a ValueError subclass. Report the problem
        # instead of silently discarding the archive metadata.
        print(f"Failed to load documents database: {e}")
        documents = []
| # Authentication functions | |
def verify_password(plain_password, hashed_password):
    """Check a plaintext password against a stored hash.

    Delegates to the module-level passlib ``pwd_context`` and returns its
    verdict unchanged.
    """
    matches = pwd_context.verify(plain_password, hashed_password)
    return matches
def get_password_hash(password):
    """Return the hash of *password* produced by the module's passlib context."""
    hashed = pwd_context.hash(password)
    return hashed
def authenticate_user(username: str, password: str):
    """Return the user record for valid credentials of an active account.

    Args:
        username: Login name looked up in ``USERS_DB``.
        password: Plaintext password verified against the stored hash.

    Returns:
        The ``USERS_DB`` entry on success, otherwise ``False`` (unknown user,
        wrong password, or an account whose ``is_active`` flag is false).
    """
    user = USERS_DB.get(username)
    if not user or not verify_password(password, user["hashed_password"]):
        return False
    # User records carry an ``is_active`` flag; a disabled account must not be
    # able to obtain a token even with the correct password.
    if not user.get("is_active", True):
        return False
    return user
def create_access_token(data: dict, expires_delta: Optional[timedelta] = None):
    """Encode *data* as a signed JWT carrying an ``exp`` (expiry) claim.

    Args:
        data: Claims to embed; copied, never mutated.
        expires_delta: Token lifetime. ``None`` means the 15-minute default.

    Returns:
        The token produced by ``jwt.encode`` with ``SECRET_KEY``/``ALGORITHM``.
    """
    # Local import: the file-level import only brings in datetime/timedelta.
    from datetime import timezone

    to_encode = data.copy()
    # Test against None rather than truthiness: ``timedelta(0)`` is falsy and
    # would otherwise silently become a 15-minute token.
    lifetime = expires_delta if expires_delta is not None else timedelta(minutes=15)
    # Timezone-aware UTC; ``datetime.utcnow()`` is deprecated since Python 3.12.
    expire = datetime.now(timezone.utc) + lifetime
    to_encode.update({"exp": expire})
    return jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM)
async def get_current_user(credentials: HTTPAuthorizationCredentials = Depends(security)):
    """FastAPI dependency that resolves the bearer token to an active user.

    Returns:
        The ``USERS_DB`` entry named by the token's ``sub`` claim.

    Raises:
        HTTPException: 401 (with ``WWW-Authenticate: Bearer``) when the token
            fails PyJWT validation (bad signature, malformed, expired), has no
            ``sub`` claim, names an unknown user, or names an inactive user.
    """
    credentials_exception = HTTPException(
        status_code=status.HTTP_401_UNAUTHORIZED,
        detail="Could not validate credentials",
        headers={"WWW-Authenticate": "Bearer"},
    )
    try:
        payload = jwt.decode(credentials.credentials, SECRET_KEY, algorithms=[ALGORITHM])
    except jwt.PyJWTError:
        # Every token problem surfaces to the client as the same plain 401.
        raise credentials_exception from None
    username = payload.get("sub")
    if username is None:
        raise credentials_exception
    user = USERS_DB.get(username)
    # Re-check ``is_active`` on every request so that a disabled account stops
    # working immediately, not only once its token expires.
    if user is None or not user.get("is_active", True):
        raise credentials_exception
    return user
| # --- Utilities --- | |
def save_index():
    """Persist the FAISS index and its parallel label list (best effort).

    Each file is first written to a ``.tmp`` sibling and then moved into place
    with ``os.replace``. A crash mid-write therefore leaves the previous file
    intact instead of a truncated one. Failures are printed and swallowed.
    """
    try:
        os.makedirs("data", exist_ok=True)
        tmp_index_path = INDEX_PATH + ".tmp"
        tmp_labels_path = LABELS_PATH + ".tmp"
        faiss.write_index(index, tmp_index_path)
        with open(tmp_labels_path, "wb") as f:
            pickle.dump(labels, f)
        # Both temp files are complete; swap them in.
        os.replace(tmp_index_path, INDEX_PATH)
        os.replace(tmp_labels_path, LABELS_PATH)
    except Exception as e:
        print(f"β Failed to save index: {e}")
def save_documents():
    """Persist the in-memory ``documents`` list to ``DOCUMENTS_PATH`` as JSON.

    The JSON is written to a ``.tmp`` sibling and moved into place with
    ``os.replace``. A half-written file is unreadable, and on the next start an
    unreadable file means an empty document list. If serialisation fails part
    way, only the temp file is affected. Failures are printed and swallowed
    (best effort).
    """
    try:
        directory = os.path.dirname(DOCUMENTS_PATH)
        if directory:
            os.makedirs(directory, exist_ok=True)
        tmp_path = DOCUMENTS_PATH + ".tmp"
        with open(tmp_path, "w", encoding="utf-8") as f:
            json.dump(documents, f, indent=2)
        os.replace(tmp_path, DOCUMENTS_PATH)
    except Exception as e:
        print(f"β Failed to save documents: {e}")
def image_from_pdf(pdf_bytes):
    """Rasterise the first page of a PDF (at 200 dpi) to an image.

    Returns:
        The first page as returned by ``convert_from_bytes``, or ``None``
        (after printing an error) when conversion fails or yields no pages.
    """
    try:
        # Only page 1 is used by callers, so don't render the rest of the PDF.
        images = convert_from_bytes(pdf_bytes, dpi=200, first_page=1, last_page=1)
    except Exception as e:
        print(f"β PDF conversion error: {e}")
        return None
    if not images:
        print("β PDF conversion error: no pages rendered")
        return None
    return images[0]
def extract_text(image):
    """Run Tesseract OCR over *image* and return the recognised text.

    *image* may be encoded bytes, a PIL image, or anything ``Image.fromarray``
    accepts. This function never raises. Missing input, empty OCR output and
    exceptions are all reported as marker-prefixed strings instead of text.
    """
    try:
        if image is None:
            return "β No image provided"
        if isinstance(image, bytes):
            pil_image = Image.open(io.BytesIO(image))
        elif isinstance(image, Image.Image):
            pil_image = image
        else:
            pil_image = Image.fromarray(image)
        if pil_image.mode != 'RGB':
            pil_image = pil_image.convert('RGB')
        # --oem 3: default OCR engine; --psm 6: treat the page as one uniform text block.
        raw_text = pytesseract.image_to_string(pil_image, config=r'--oem 3 --psm 6')
        cleaned = raw_text.strip()
        return cleaned or "β No text detected"
    except Exception as e:
        return f"β OCR error: {str(e)}"
def get_clip_embedding(image):
    """Return the L2-normalised CLIP image embedding of *image*, or ``None``.

    *image* may be encoded bytes, a PIL image, or anything ``Image.fromarray``
    accepts. The result is a 1-D float32 NumPy vector. For ViT-B/32 it is
    512 wide, matching the FAISS index. Errors are printed and yield ``None``.
    """
    try:
        if image is None:
            return None
        if isinstance(image, bytes):
            image = Image.open(io.BytesIO(image))
        elif not isinstance(image, Image.Image):
            image = Image.fromarray(image)
        if image.mode != 'RGB':
            image = image.convert('RGB')
        image_input = preprocess(image).unsqueeze(0).to(device)
        with torch.no_grad():
            image_features = clip_model.encode_image(image_input)
        # Unit-normalise so that L2 distance in the index tracks cosine similarity.
        image_features = image_features / image_features.norm(dim=-1, keepdim=True)
        # clip.load keeps half-precision weights on GPU, so the features may be
        # float16; FAISS operates on float32, so cast explicitly.
        return image_features.cpu().numpy()[0].astype(np.float32)
    except Exception as e:
        print(f"β CLIP embedding error: {e}")
        return None
def save_uploaded_file(file_content: bytes, filename: Optional[str], uploads_dir: Optional[str] = None) -> str:
    """Store uploaded bytes under a fresh UUID-based name and return that name.

    Only the extension of the client-supplied *filename* is kept, so the stored
    name contains no user-controlled path components.

    Args:
        file_content: Raw bytes to write.
        filename: Original client filename, used only for its extension. It
            may be ``None`` or empty (UploadFile.filename is optional), in
            which case the stored file has no extension.
        uploads_dir: Target directory; defaults to ``UPLOADS_DIR``.

    Returns:
        The stored base name ``<uuid4><ext>``, relative to the target directory.
    """
    target_dir = UPLOADS_DIR if uploads_dir is None else uploads_dir
    extension = os.path.splitext(filename or "")[1]
    saved_filename = f"{uuid.uuid4()}{extension}"
    with open(os.path.join(target_dir, saved_filename), "wb") as f:
        f.write(file_content)
    return saved_filename
| # --- API Endpoints --- | |
async def dashboard():
    """Serve the single-page dashboard from ``static/index.html``."""
    # Explicit UTF-8: the default text encoding is locale-dependent, so a page
    # containing non-ASCII characters could fail to decode on some hosts.
    with open("static/index.html", "r", encoding="utf-8") as f:
        return HTMLResponse(content=f.read())
async def login(username: str = Form(...), password: str = Form(...)):
    """Exchange form-posted credentials for a bearer token.

    Returns:
        A dict with ``access_token``, ``token_type`` (``"bearer"``) and
        ``username``. The token lifetime is ``ACCESS_TOKEN_EXPIRE_MINUTES``.

    Raises:
        HTTPException: 401 when ``authenticate_user`` rejects the credentials.
    """
    account = authenticate_user(username, password)
    if not account:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Incorrect username or password",
        )
    lifetime = timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES)
    account_name = account["username"]
    token = create_access_token(data={"sub": account_name}, expires_delta=lifetime)
    return {
        "access_token": token,
        "token_type": "bearer",
        "username": account_name,
    }
async def upload_category(
    file: UploadFile = File(...),
    label: str = Form(...),
    current_user: dict = Depends(get_current_user)
):
    """Register an example image (or a PDF's first page) for category *label*.

    The CLIP embedding is appended to the FAISS index and *label* to the
    parallel ``labels`` list, so index row i maps to ``labels[i]``. Both are
    then persisted via ``save_index``.

    Raises:
        HTTPException: 400 for a blank label or a file that cannot be decoded
            or embedded; 500 for unexpected failures.
    """
    try:
        if not label or not label.strip():
            raise HTTPException(status_code=400, detail="Please provide a label")
        label = label.strip()
        file_content = await file.read()
        if file.content_type and file.content_type.startswith('application/pdf'):
            image = image_from_pdf(file_content)
        else:
            try:
                image = Image.open(io.BytesIO(file_content))
            except OSError:
                # PIL signals undecodable data with UnidentifiedImageError (an OSError).
                image = None
        if image is None:
            raise HTTPException(status_code=400, detail="Failed to process image")
        embedding = get_clip_embedding(image)
        if embedding is None:
            raise HTTPException(status_code=400, detail="Failed to generate embedding")
        # FAISS expects an (n, d) float32 matrix.
        index.add(np.asarray([embedding], dtype=np.float32))
        labels.append(label)
        save_index()
        return {"message": f"β Added category '{label}' (Total: {len(labels)} categories)", "status": "success"}
    except HTTPException:
        # Let the deliberate 4xx responses above through instead of re-wrapping them as 500.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
async def classify_document(
    file: UploadFile = File(...),
    current_user: dict = Depends(get_current_user)
):
    """Classify an uploaded document against the registered category examples.

    The upload (an image, or the first page of a PDF) is embedded with CLIP
    and searched against the FAISS index for up to three nearest examples.
    If the best match reaches ``confidence_threshold``, the file is stored,
    OCR'd and appended to ``documents``. Otherwise only the ranking is
    returned, with ``status == "low_confidence"``.

    Raises:
        HTTPException: 400 when no categories exist, the file cannot be decoded
            or embedded, or the search yields no valid neighbour; 500 for
            unexpected failures.
    """
    try:
        if len(labels) == 0:
            raise HTTPException(status_code=400, detail="No categories in database. Please add some first.")

        file_content = await file.read()
        if file.content_type and file.content_type.startswith('application/pdf'):
            image = image_from_pdf(file_content)
        else:
            try:
                image = Image.open(io.BytesIO(file_content))
            except OSError:
                # PIL signals undecodable data with UnidentifiedImageError (an OSError).
                image = None
        if image is None:
            raise HTTPException(status_code=400, detail="Failed to process image")

        embedding = get_clip_embedding(image)
        if embedding is None:
            raise HTTPException(status_code=400, detail="Failed to generate embedding")

        # Search for top 3 matches
        k = min(3, len(labels))
        D, I = index.search(np.asarray([embedding], dtype=np.float32), k)

        # FAISS pads missing neighbours with id -1 (e.g. when the index holds
        # fewer vectors than there are labels). A bare ``< len(labels)`` check
        # lets -1 through and silently picks ``labels[-1]``.
        best_idx = int(I[0][0])
        if not 0 <= best_idx < len(labels):
            raise HTTPException(status_code=400, detail="Document not recognized")

        # NOTE: embeddings are unit-normalised, and the default index is
        # IndexFlatL2, which reports squared L2 distance (= 2 - 2*cos). So
        # ``1 - D`` is 2*cos - 1, not the cosine itself, and the 0.35
        # threshold corresponds to cos ~= 0.675. The formula is kept so that
        # existing scores and the threshold keep their meaning.
        # float() turns numpy float32 into a built-in float. numpy scalars are
        # not JSON-serialisable and would break both this response and
        # save_documents().
        similarity = float(1 - D[0][0])
        confidence_threshold = 0.35
        best_match = labels[best_idx]
        matches = [
            {"category": labels[int(idx)], "similarity": round(float(1 - dist), 3)}
            for dist, idx in zip(D[0], I[0])
            if 0 <= idx < len(labels)
        ]

        if similarity < confidence_threshold:
            return {
                "status": "low_confidence",
                "category": best_match,
                "similarity": round(similarity, 3),
                "confidence": "low",
                "matches": matches,
                "document_saved": False
            }

        # Save classified document
        saved_filename = save_uploaded_file(file_content, file.filename or "")
        ocr_text = extract_text(image)
        document = {
            "id": str(uuid.uuid4()),
            "filename": saved_filename,
            "original_filename": file.filename,
            "category": best_match,
            "similarity": round(similarity, 3),
            "ocr_text": ocr_text,
            "upload_date": datetime.now().isoformat(),
            "file_path": os.path.join(UPLOADS_DIR, saved_filename)
        }
        documents.append(document)
        save_documents()
        return {
            "status": "success",
            "category": best_match,
            "similarity": round(similarity, 3),
            "confidence": "high",
            "matches": matches,
            "document_saved": True,
            "document_id": document["id"]
        }
    except HTTPException:
        # Preserve the deliberate 4xx responses above instead of converting them to 500.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
async def get_categories(current_user: dict = Depends(get_current_user)):
    """List the distinct category labels and how many examples each has.

    Categories come back in the order they were first registered.
    ``list(set(...))`` of strings would instead depend on the per-process
    hash seed and reshuffle on every restart.
    """
    category_counts: Dict[str, int] = {}
    for label in labels:
        category_counts[label] = category_counts.get(label, 0) + 1
    return {"categories": list(category_counts), "counts": category_counts}
async def get_documents_by_category(
    category: str,
    current_user: dict = Depends(get_current_user)
):
    """Return every stored document whose ``category`` equals *category*, plus the count."""
    matching = list(filter(lambda doc: doc["category"] == category, documents))
    return {"documents": matching, "count": len(matching)}
async def get_all_documents(current_user: dict = Depends(get_current_user)):
    """Return the full in-memory document list together with its length."""
    total = len(documents)
    return {"documents": documents, "count": total}
async def delete_document(
    document_id: str,
    current_user: dict = Depends(get_current_user)
):
    """Delete a stored document's file and its metadata record.

    Raises:
        HTTPException: 404 when no document has *document_id*; 500 when
            removing the file or updating the database fails unexpectedly.
    """
    try:
        # Find document
        document_index = next(
            (i for i, doc in enumerate(documents) if doc["id"] == document_id),
            None,
        )
        if document_index is None:
            raise HTTPException(status_code=404, detail="Document not found")

        # Delete physical file; a file that is already gone is not an error.
        file_path = documents[document_index].get("file_path")
        if file_path:
            try:
                os.remove(file_path)
            except FileNotFoundError:
                pass

        # Remove from documents list
        documents.pop(document_index)
        save_documents()
        return {"message": "Document deleted successfully", "status": "success"}
    except HTTPException:
        # Preserve the 404 instead of re-wrapping it as a 500.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
async def ocr_document(
    file: UploadFile = File(...),
    current_user: dict = Depends(get_current_user)
):
    """Run OCR on an uploaded image (or a PDF's first page) without storing it.

    Returns ``{"text": ..., "status": "success"}``. Note that ``extract_text``
    reports OCR failures inside ``text`` rather than raising.

    Raises:
        HTTPException: 400 when the upload cannot be decoded as an image;
            500 for unexpected failures.
    """
    try:
        file_content = await file.read()
        if file.content_type and file.content_type.startswith('application/pdf'):
            image = image_from_pdf(file_content)
        else:
            try:
                image = Image.open(io.BytesIO(file_content))
            except OSError:
                # PIL signals undecodable data with UnidentifiedImageError (an OSError).
                image = None
        if image is None:
            raise HTTPException(status_code=400, detail="Failed to process image")
        text = extract_text(image)
        return {"text": text, "status": "success"}
    except HTTPException:
        # Preserve the 400 instead of re-wrapping it as a 500.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
async def get_stats(current_user: dict = Depends(get_current_user)):
    """Summarise the archive.

    Reports the number of distinct category labels, the number of stored
    documents, and how many documents fall into each category.
    """
    distribution = {}
    for record in documents:
        cat = record["category"]
        distribution[cat] = distribution.get(cat, 0) + 1
    distinct_categories = len(set(labels))
    return {
        "total_categories": distinct_categories,
        "total_documents": len(documents),
        "category_distribution": distribution,
    }
if __name__ == "__main__":
    # Development entry point: serve the app with uvicorn on all interfaces, port 8000.
    import uvicorn

    uvicorn.run(app=app, port=8000, host="0.0.0.0")