Spaces:
Sleeping
Sleeping
from fastapi import FastAPI, File, UploadFile, Form, HTTPException, Depends, status | |
from fastapi.responses import HTMLResponse, JSONResponse | |
from fastapi.staticfiles import StaticFiles | |
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials | |
import pytesseract | |
from PIL import Image | |
import numpy as np | |
import faiss | |
import os | |
import pickle | |
from pdf2image import convert_from_bytes | |
import torch | |
import clip | |
import io | |
import json | |
import uuid | |
from datetime import datetime, timedelta | |
from typing import List, Dict, Any, Optional | |
import base64 | |
import jwt | |
from passlib.context import CryptContext | |
app = FastAPI(title="Handwritten Archive Document Digitalization System")

# Security configuration
# The JWT signing key is taken from the SECRET_KEY environment variable when
# it is set to a non-empty value. The literal below is only a development
# fallback and must not be relied on in production.
SECRET_KEY = os.environ.get("SECRET_KEY") or "your-secret-key-change-this-in-production"
ALGORITHM = "HS256"  # algorithm passed to jwt.encode / jwt.decode
ACCESS_TOKEN_EXPIRE_MINUTES = 30  # lifetime of tokens issued by login()
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
security = HTTPBearer()
# Default admin user (change in production)
# In-memory user table keyed by username; the password hash is computed at
# import time with pwd_context.
USERS_DB = {
    "admin": dict(
        username="admin",
        hashed_password=pwd_context.hash("admin123"),
        is_active=True,
    ),
}
# Mount static files
# StaticFiles validates that the directory exists when it is constructed, so
# create it first; otherwise a fresh checkout fails at import time.
os.makedirs("static", exist_ok=True)
app.mount("/static", StaticFiles(directory="static"), name="static")
# --- Load or Initialize Model/Index ---
# Pick the GPU when torch reports one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
clip_model, preprocess = clip.load("ViT-B/32", device=device)

# On-disk locations for the persisted state.
INDEX_PATH = "data/index.faiss"
LABELS_PATH = "data/labels.pkl"
DOCUMENTS_PATH = "data/documents.json"
UPLOADS_DIR = "data/uploads"

# Ensure directories exist
for _required_dir in ("data", "static", UPLOADS_DIR):
    os.makedirs(_required_dir, exist_ok=True)
del _required_dir

# Initialize index and labels with error handling
# Start from empty in-memory state; later start-up code may replace it with
# data loaded from disk. 512 is the vector dimensionality of the index
# (presumably the ViT-B/32 image embedding size -- confirm).
index = faiss.IndexFlatL2(512)
labels = []
documents = []
# Restore the persisted FAISS index and its parallel label list, if both exist.
if os.path.exists(INDEX_PATH) and os.path.exists(LABELS_PATH):
    try:
        index = faiss.read_index(INDEX_PATH)
        with open(LABELS_PATH, "rb") as f:
            # NOTE(security): pickle.load can execute arbitrary code; this is
            # only safe while the data/ directory is written by this app alone.
            labels = pickle.load(f)
        print(f"β Loaded existing index with {len(labels)} labels")
    except (RuntimeError, EOFError, pickle.UnpicklingError) as e:
        print(f"β οΈ Failed to load existing index: {e}")
        print("π Starting with fresh index")
        # The index may already have been replaced by read_index() before the
        # labels failed to load. Reset both so they cannot disagree.
        index = faiss.IndexFlatL2(512)
        labels = []
        # Remove the unreadable files so the next start-up does not retry them.
        if os.path.exists(INDEX_PATH):
            os.remove(INDEX_PATH)
        if os.path.exists(LABELS_PATH):
            os.remove(LABELS_PATH)
# Load documents database
if os.path.exists(DOCUMENTS_PATH):
    try:
        with open(DOCUMENTS_PATH, 'r') as f:
            documents = json.load(f)
    except (OSError, ValueError) as e:
        # json.JSONDecodeError and UnicodeDecodeError are ValueError subclasses.
        # Report the problem instead of silently discarding it, then continue
        # with an empty database as before.
        print(f"Failed to load documents database: {e}")
        documents = []
    if not isinstance(documents, list):
        # The endpoints append to and iterate over a list of dicts.
        print("Documents database is not a JSON list; starting empty")
        documents = []
# Authentication functions
def verify_password(plain_password, hashed_password):
    """Check a plaintext password against a stored hash.

    Delegates to ``pwd_context.verify`` and returns its result unchanged.
    """
    is_match = pwd_context.verify(plain_password, hashed_password)
    return is_match
def get_password_hash(password):
    """Return the hash of ``password`` produced by ``pwd_context.hash``."""
    digest = pwd_context.hash(password)
    return digest
def authenticate_user(username: str, password: str):
    """Look up ``username`` in ``USERS_DB`` and verify ``password``.

    Returns the stored user record when the user exists and the password
    matches its ``hashed_password``; otherwise returns ``False``.
    """
    account = USERS_DB.get(username)
    if account and verify_password(password, account["hashed_password"]):
        return account
    return False
def create_access_token(data: dict, expires_delta: Optional[timedelta] = None):
    """Build a signed JWT carrying ``data`` plus an ``exp`` claim.

    Args:
        data: Claims to embed; the mapping itself is not modified.
        expires_delta: Token lifetime. When omitted (or falsy), the token
            expires 15 minutes from now.

    Returns:
        The token produced by ``jwt.encode`` with ``SECRET_KEY`` and
        ``ALGORITHM``.
    """
    lifetime = expires_delta or timedelta(minutes=15)
    claims = {**data, "exp": datetime.utcnow() + lifetime}
    return jwt.encode(claims, SECRET_KEY, algorithm=ALGORITHM)
async def get_current_user(credentials: HTTPAuthorizationCredentials = Depends(security)):
    """Resolve the bearer token in ``credentials`` to a ``USERS_DB`` record.

    Raises:
        HTTPException: 401 when the token cannot be decoded, has no ``sub``
            claim, or names a user that is not in ``USERS_DB``.
    """
    unauthorized = HTTPException(
        status_code=status.HTTP_401_UNAUTHORIZED,
        detail="Could not validate credentials",
        headers={"WWW-Authenticate": "Bearer"},
    )
    try:
        claims = jwt.decode(credentials.credentials, SECRET_KEY, algorithms=[ALGORITHM])
    except jwt.PyJWTError:
        raise unauthorized

    subject: str = claims.get("sub")
    if subject is None:
        raise unauthorized

    account = USERS_DB.get(subject)
    if account is None:
        raise unauthorized
    return account
# --- Utilities ---
def save_index():
    """Write the FAISS index and the label list to disk.

    Best effort: any exception is caught and reported with ``print`` rather
    than propagated to the caller.
    """
    try:
        os.makedirs("data", exist_ok=True)
        faiss.write_index(index, INDEX_PATH)
        with open(LABELS_PATH, "wb") as labels_file:
            pickle.dump(labels, labels_file)
    except Exception as err:
        print(f"β Failed to save index: {err}")
def save_documents():
    """Serialize the ``documents`` list to ``DOCUMENTS_PATH`` as indented JSON.

    Best effort: any exception is caught and reported with ``print``.
    """
    try:
        with open(DOCUMENTS_PATH, 'w') as out_file:
            json.dump(documents, out_file, indent=2)
    except Exception as err:
        print(f"β Failed to save documents: {err}")
def image_from_pdf(pdf_bytes):
    """Render the first page of a PDF and return it as an image.

    Only page 1 is rasterized (at 200 dpi), because callers use nothing
    beyond the first page; rendering the whole document would be wasted work.

    Args:
        pdf_bytes: Raw bytes of the PDF file.

    Returns:
        The first page as returned by ``convert_from_bytes``, or ``None`` when
        conversion fails or yields no pages. Failures are reported with
        ``print``.
    """
    try:
        pages = convert_from_bytes(pdf_bytes, dpi=200, first_page=1, last_page=1)
    except Exception as e:
        print(f"β PDF conversion error: {e}")
        return None
    if not pages:
        print("β PDF conversion error: document has no pages")
        return None
    return pages[0]
def extract_text(image):
    """Run Tesseract OCR on ``image`` and return the recognized text.

    ``image`` may be raw encoded bytes, a PIL image, or an array accepted by
    ``Image.fromarray``. The image is converted to RGB before OCR.

    Returns:
        The stripped OCR text, or a human-readable marker string when no image
        was given, no text was found, or OCR raised an exception.
    """
    if image is None:
        return "β No image provided"
    try:
        if isinstance(image, bytes):
            picture = Image.open(io.BytesIO(image))
        elif isinstance(image, Image.Image):
            picture = image
        else:
            picture = Image.fromarray(image)

        if picture.mode != 'RGB':
            picture = picture.convert('RGB')

        raw_text = pytesseract.image_to_string(picture, config=r'--oem 3 --psm 6')
        cleaned = raw_text.strip()
        return cleaned or "β No text detected"
    except Exception as err:
        return f"β OCR error: {str(err)}"
def get_clip_embedding(image):
    """Compute an L2-normalized CLIP image embedding as a float32 vector.

    ``image`` may be raw encoded bytes, a PIL image, or an array accepted by
    ``Image.fromarray``. The image is converted to RGB before preprocessing.

    Returns:
        A 1-D ``numpy`` float32 array, or ``None`` when ``image`` is ``None``
        or any step fails. Failures are reported with ``print``.
    """
    try:
        if image is None:
            return None
        if isinstance(image, bytes):
            image = Image.open(io.BytesIO(image))
        elif not isinstance(image, Image.Image):
            image = Image.fromarray(image)
        if image.mode != 'RGB':
            image = image.convert('RGB')
        image_input = preprocess(image).unsqueeze(0).to(device)
        with torch.no_grad():
            image_features = clip_model.encode_image(image_input)
            # Unit-normalize so distances between embeddings are comparable.
            image_features = image_features / image_features.norm(dim=-1, keepdim=True)
        vector = image_features.cpu().numpy()[0]
        # FAISS only accepts float32 input, while the tensor dtype follows the
        # model's precision (NOTE(review): may be half precision on GPU).
        # Cast explicitly so index.add / index.search always get float32.
        return vector.astype(np.float32, copy=False)
    except Exception as e:
        print(f"β CLIP embedding error: {e}")
        return None
def save_uploaded_file(file_content: bytes, filename: str) -> str:
    """Store ``file_content`` in ``UPLOADS_DIR`` under a fresh UUID-based name.

    Only the extension of the client-supplied ``filename`` is reused; the stem
    is replaced by a random UUID, so the client does not control where the
    file is written.

    Args:
        file_content: Raw bytes to write.
        filename: Original file name. May be empty or ``None`` (uploads are
            not guaranteed to carry a name), in which case no extension is
            appended.

    Returns:
        The generated file name, relative to ``UPLOADS_DIR``.
    """
    file_id = str(uuid.uuid4())
    # splitext(None) raises TypeError, so treat a missing name as "".
    file_extension = os.path.splitext(filename or "")[1]
    saved_filename = f"{file_id}{file_extension}"
    file_path = os.path.join(UPLOADS_DIR, saved_filename)
    with open(file_path, 'wb') as f:
        f.write(file_content)
    return saved_filename
# --- API Endpoints ---
async def dashboard():
    """Return the contents of ``static/index.html`` as an HTML response."""
    with open("static/index.html", "r") as page:
        html = page.read()
    return HTMLResponse(content=html)
async def login(username: str = Form(...), password: str = Form(...)):
    """Exchange form-posted credentials for a bearer token.

    Raises:
        HTTPException: 401 when ``authenticate_user`` rejects the credentials.

    Returns:
        A dict with ``access_token``, ``token_type`` and ``username``.
    """
    account = authenticate_user(username, password)
    if not account:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Incorrect username or password",
        )

    token = create_access_token(
        data={"sub": account["username"]},
        expires_delta=timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES),
    )
    return {
        "access_token": token,
        "token_type": "bearer",
        "username": account["username"],
    }
async def upload_category(
    file: UploadFile = File(...),
    label: str = Form(...),
    current_user: dict = Depends(get_current_user)
):
    """Add one reference image (or first PDF page) to the index under ``label``.

    The upload is embedded with CLIP, appended to the FAISS index and the
    parallel ``labels`` list, and both are persisted via ``save_index``.

    Raises:
        HTTPException: 400 for a blank label, an unreadable PDF, or a failed
            embedding; 500 for any other unexpected error.
    """
    try:
        if not label or not label.strip():
            raise HTTPException(status_code=400, detail="Please provide a label")
        label = label.strip()
        file_content = await file.read()
        if file.content_type and file.content_type.startswith('application/pdf'):
            image = image_from_pdf(file_content)
        else:
            image = Image.open(io.BytesIO(file_content))
        if image is None:
            raise HTTPException(status_code=400, detail="Failed to process image")
        embedding = get_clip_embedding(image)
        if embedding is None:
            raise HTTPException(status_code=400, detail="Failed to generate embedding")
        # FAISS requires a 2-D float32 array of shape (n_vectors, dim).
        index.add(np.asarray([embedding], dtype=np.float32))
        labels.append(label)
        save_index()
        return {"message": f"β Added category '{label}' (Total: {len(labels)} categories)", "status": "success"}
    except HTTPException:
        # Deliberate 4xx responses must not be rewritten as 500 by the
        # generic handler below.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e
async def classify_document(
    file: UploadFile = File(...),
    current_user: dict = Depends(get_current_user)
):
    """Classify an uploaded document against the stored categories.

    The upload (image, or first page of a PDF) is embedded with CLIP and
    compared with the FAISS index. Up to three nearest categories are
    reported. When the best score reaches the confidence threshold, the file
    is saved, OCR'd and recorded in ``documents``.

    Raises:
        HTTPException: 400 when no categories exist, the file cannot be
            processed or embedded, or the nearest neighbour does not map to a
            known label; 500 for any other unexpected error.

    Returns:
        A dict with ``status``, ``category``, ``similarity``, ``confidence``,
        ``matches`` and ``document_saved`` (plus ``document_id`` when saved).
    """
    try:
        if len(labels) == 0:
            raise HTTPException(status_code=400, detail="No categories in database. Please add some first.")
        file_content = await file.read()
        if file.content_type and file.content_type.startswith('application/pdf'):
            image = image_from_pdf(file_content)
        else:
            image = Image.open(io.BytesIO(file_content))
        if image is None:
            raise HTTPException(status_code=400, detail="Failed to process image")
        embedding = get_clip_embedding(image)
        if embedding is None:
            raise HTTPException(status_code=400, detail="Failed to generate embedding")

        # Search for top 3 matches
        k = min(3, len(labels))
        # FAISS requires a 2-D float32 query array.
        distances, indices = index.search(np.asarray([embedding], dtype=np.float32), k=k)

        # FAISS pads missing neighbours with -1, which would otherwise pass a
        # plain "< len(labels)" check and silently select labels[-1].
        top_idx = int(indices[0][0])
        if not 0 <= top_idx < len(labels):
            raise HTTPException(status_code=400, detail="Document not recognized")

        # NOTE(review): the score is computed as 1 - distance, exactly as
        # before. If the index returns squared L2 distances over unit vectors,
        # cosine similarity would be 1 - distance / 2; confirm that the 0.35
        # threshold was tuned for the current formula before changing it.
        # float() turns the numpy scalar into a JSON-serializable Python float.
        similarity = float(1 - distances[0][0])
        confidence_threshold = 0.35
        best_match = labels[top_idx]

        matches = []
        for dist, idx in zip(distances[0], indices[0]):
            idx = int(idx)
            if 0 <= idx < len(labels):
                matches.append({
                    "category": labels[idx],
                    "similarity": round(float(1 - dist), 3),
                })

        if similarity < confidence_threshold:
            return {
                "status": "low_confidence",
                "category": best_match,
                "similarity": round(similarity, 3),
                "confidence": "low",
                "matches": matches,
                "document_saved": False
            }

        # Save classified document
        # The upload may carry no file name; pass "" so no extension is added.
        saved_filename = save_uploaded_file(file_content, file.filename or "")
        ocr_text = extract_text(image)
        document = {
            "id": str(uuid.uuid4()),
            "filename": saved_filename,
            "original_filename": file.filename,
            "category": best_match,
            "similarity": round(similarity, 3),
            "ocr_text": ocr_text,
            "upload_date": datetime.now().isoformat(),
            "file_path": os.path.join(UPLOADS_DIR, saved_filename)
        }
        documents.append(document)
        save_documents()
        return {
            "status": "success",
            "category": best_match,
            "similarity": round(similarity, 3),
            "confidence": "high",
            "matches": matches,
            "document_saved": True,
            "document_id": document["id"]
        }
    except HTTPException:
        # Deliberate 4xx responses must not be rewritten as 500 by the
        # generic handler below.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e
async def get_categories(current_user: dict = Depends(get_current_user)):
    """List the distinct category labels and how many index entries each has.

    Returns:
        ``{"categories": [...], "counts": {label: n}}``. Categories are listed
        in first-seen order, so the response is stable between calls and
        process restarts (a ``set`` would give an arbitrary order).
    """
    category_counts = {}
    for label in labels:
        category_counts[label] = category_counts.get(label, 0) + 1
    # Dict keys preserve insertion order and are already de-duplicated.
    categories = list(category_counts)
    return {"categories": categories, "counts": category_counts}
async def get_documents_by_category(
    category: str,
    current_user: dict = Depends(get_current_user)
):
    """Return every stored document whose ``category`` equals ``category``.

    Returns:
        ``{"documents": [...], "count": n}`` with matches in stored order.
    """
    selected = []
    for entry in documents:
        if entry["category"] == category:
            selected.append(entry)
    return {"documents": selected, "count": len(selected)}
async def get_all_documents(current_user: dict = Depends(get_current_user)):
    """Return the full ``documents`` list together with its length."""
    total = len(documents)
    return {"documents": documents, "count": total}
async def delete_document(
    document_id: str,
    current_user: dict = Depends(get_current_user)
):
    """Delete a stored document record and its uploaded file.

    Raises:
        HTTPException: 404 when no document has ``document_id``; 500 for any
            other unexpected error (for example a failing file removal).

    Returns:
        A success message dict.
    """
    try:
        # Find document
        document_index = None
        document_to_delete = None
        for i, doc in enumerate(documents):
            if doc["id"] == document_id:
                document_index = i
                document_to_delete = doc
                break
        if document_to_delete is None:
            raise HTTPException(status_code=404, detail="Document not found")
        # Delete physical file
        file_path = document_to_delete.get("file_path")
        if file_path and os.path.exists(file_path):
            os.remove(file_path)
        # Remove from documents list
        documents.pop(document_index)
        save_documents()
        return {"message": "Document deleted successfully", "status": "success"}
    except HTTPException:
        # Keep the 404 as a 404 instead of rewriting it as a 500 below.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e
async def ocr_document(
    file: UploadFile = File(...),
    current_user: dict = Depends(get_current_user)
):
    """Run OCR on an uploaded image (or the first page of a PDF).

    Raises:
        HTTPException: 400 when the PDF cannot be rendered; 500 for any other
            unexpected error.

    Returns:
        ``{"text": <result of extract_text>, "status": "success"}``.
    """
    try:
        file_content = await file.read()
        if file.content_type and file.content_type.startswith('application/pdf'):
            image = image_from_pdf(file_content)
        else:
            image = Image.open(io.BytesIO(file_content))
        if image is None:
            raise HTTPException(status_code=400, detail="Failed to process image")
        text = extract_text(image)
        return {"text": text, "status": "success"}
    except HTTPException:
        # Keep the deliberate 400 instead of rewriting it as a 500 below.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e
async def get_stats(current_user: dict = Depends(get_current_user)):
    """Summarize the archive.

    Returns:
        A dict with the number of distinct labels, the number of stored
        documents, and a per-category document count.
    """
    distribution = {}
    for entry in documents:
        key = entry["category"]
        distribution[key] = distribution.get(key, 0) + 1

    distinct_labels = set(labels)
    return {
        "total_categories": len(distinct_labels),
        "total_documents": len(documents),
        "category_distribution": distribution,
    }
# Script entry point: serve the app with uvicorn on all interfaces, port 8000.
if __name__ == "__main__":
    import uvicorn

    uvicorn.run(
        app,
        host="0.0.0.0",
        port=8000,
    )