from fastapi import FastAPI, APIRouter, File, UploadFile, Form
from fastapi.concurrency import run_in_threadpool
from fastapi.middleware.cors import CORSMiddleware
from typing import Optional
from PIL import Image
import urllib.request
from urllib.parse import urlparse
from io import BytesIO
import json
from config import settings
import utils
from routers import inference, training
from routers.donut_inference import process_document_donut
from huggingface_hub import login
import os

# MIME types accepted for direct uploads ("image/jpg" is non-standard, but
# some clients send it).
_ALLOWED_CONTENT_TYPES = ("image/jpeg", "image/jpg")
# Only remote fetches are allowed for `image_url`; urllib.request.urlopen would
# otherwise also honour file:// (local files), ftp:// and data: URLs.
_ALLOWED_URL_SCHEMES = ("http", "https")
# Upper bound (seconds) for blocking socket operations while downloading an
# image, so an unresponsive host cannot hang a request indefinitely.
_URL_TIMEOUT_SECONDS = 30

# Log in to the Hugging Face Hub using the token from the environment. Only
# attempt it when a token is actually set: login(None) falls back to an
# interactive prompt (in current huggingface_hub versions), which cannot be
# answered inside a server process.
_hf_token = os.getenv("HUGGINGFACE_KEY")
if _hf_token:
    login(_hf_token)
else:
    print("[WARN] HUGGINGFACE_KEY is not set; skipping Hugging Face login.")

app = FastAPI(openapi_url="/api/v1/sparrow-ml/openapi.json", docs_url="/api/v1/sparrow-ml/docs")

# NOTE(review): wildcard origins combined with allow_credentials=True lets any
# site issue credentialed cross-origin requests; consider an explicit list.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
    allow_credentials=True,
)

app.include_router(inference.router, prefix="/api-inference/v1/sparrow-ml", tags=["Inference"])
app.include_router(training.router, prefix="/api-training/v1/sparrow-ml", tags=["Training"])

# NOTE(review): `router` is not passed to `app.include_router` anywhere in the
# code shown here, so the /inference and /statistics routes below are only
# served if it is included elsewhere (after these routes are defined). Confirm.
router = APIRouter()


def count_values(obj):
    """Return the number of leaf values in a nested dict/list structure.

    Dicts (by value) and lists are traversed recursively; dict keys are not
    counted. Any other object, including strings and tuples, counts as one
    value. Empty dicts and lists contribute zero.
    """
    if isinstance(obj, dict):
        return sum(count_values(v) for v in obj.values())
    elif isinstance(obj, list):
        return sum(count_values(i) for i in obj)
    else:
        return 1


def _download_image_bytes(image_url: str) -> bytes:
    """Fetch and return the raw response body at `image_url`.

    Blocking; callers in async code should run it in a thread pool.
    Raises OSError (urllib.error.URLError/HTTPError, timeouts) or ValueError
    (malformed URL) on failure.
    """
    with urllib.request.urlopen(image_url, timeout=_URL_TIMEOUT_SECONDS) as response:
        return response.read()


@router.post("/inference")
async def run_inference(
    file: Optional[UploadFile] = File(None),
    image_url: Optional[str] = Form(None),
    shipper_id: int = Form(...),
    model_in_use: str = Form('donut'),
):
    """Run Donut document inference on an uploaded JPG or on an image URL.

    Args:
        file: Uploaded image; takes precedence over `image_url`. Must have a
            JPEG content type.
        image_url: http(s) URL of an image, used when no file is uploaded.
        shipper_id: Selects the model via `settings.get_model_url`.
        model_in_use: Accepted for API compatibility; not read by this code.

    Returns:
        The result of `process_document_donut`, or a dict with an "error"
        key for rejected input, or {"info": "No input provided"} when neither
        input is given. Each successful run is appended to the inference
        stats file via `utils.log_stats`.
    """
    # Dynamically select the model for this shipper.
    model_url = settings.get_model_url(shipper_id)
    model_name = model_url.replace("https://huggingface.co/spaces/", "")
    print(f"[DEBUG] Using model: {model_name}")

    if file:
        if file.content_type not in _ALLOWED_CONTENT_TYPES:
            return {"error": "Invalid file type. Only JPG images are allowed."}
        image_bytes = await file.read()
        source_name = file.filename
    elif image_url:
        parsed_url = urlparse(image_url)
        if parsed_url.scheme.lower() not in _ALLOWED_URL_SCHEMES:
            return {"error": "Invalid image URL. Only http and https URLs are allowed."}
        try:
            # urlopen blocks; keep it off the event loop.
            image_bytes = await run_in_threadpool(_download_image_bytes, image_url)
        except (OSError, ValueError):
            return {"error": "Could not download image from URL."}
        # Use only the last path segment so query strings (e.g. signed-URL
        # tokens) are not written into the stats file.
        source_name = parsed_url.path.rsplit("/", 1)[-1]
    else:
        return {"info": "No input provided"}

    try:
        image = Image.open(BytesIO(image_bytes))
    except OSError:  # PIL.UnidentifiedImageError is an OSError subclass
        return {"error": "Provided content is not a readable image."}

    # NOTE: inference runs inline (blocking the event loop) exactly as before;
    # moving it to a thread pool would allow concurrent model runs, which has
    # not been verified as safe for the Donut pipeline.
    with image:
        result, processing_time = process_document_donut(image, model_url)

    utils.log_stats(settings.inference_stats_file,
                    [processing_time, count_values(result), source_name, model_name])
    print(f"Processing time: {processing_time:.2f} seconds")
    return result


@router.get("/statistics")
async def get_statistics():
    """Return the JSON contents of the inference stats file.

    Returns an empty list when the file does not exist or is not valid JSON.
    """
    try:
        with open(settings.inference_stats_file, 'r', encoding='utf-8') as stats_file:
            return json.load(stats_file)
    except (FileNotFoundError, json.JSONDecodeError):
        return []


@app.get("/")
async def root():
    """Simple liveness/identity endpoint."""
    return {"message": "Senga delivery notes inferencing"}