Spaces:
Sleeping
Sleeping
File size: 3,078 Bytes
4382401 cbbd0e6 4382401 8cf136d 4382401 c05980d 969487e c05980d c2d58b3 8cf136d a63a07c 8cf136d 4382401 c2d58b3 4382401 969487e 4382401 969487e 4382401 c2d58b3 4382401 969487e 4382401 969487e 4382401 969487e 4382401 969487e 4382401 969487e 4382401 969487e 4382401 969487e 4382401 969487e 4382401 a63a07c 1a23ce8 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 |
from fastapi import FastAPI, APIRouter, File, UploadFile, Form
from fastapi.middleware.cors import CORSMiddleware
from typing import Optional
from PIL import Image
import urllib.request
from io import BytesIO
import json
from config import settings
import utils
from routers import inference, training
from routers.donut_inference import process_document_donut
from huggingface_hub import login
import os
# Authenticate with the Hugging Face Hub using a token from the environment.
# login() is only called when a token is present: with token=None,
# huggingface_hub.login falls back to an interactive prompt, which a headless
# server process cannot answer.
_hf_token = os.getenv("HUGGINGFACE_KEY")
if _hf_token:
    login(_hf_token)
else:
    print("[WARNING] HUGGINGFACE_KEY is not set; skipping Hugging Face login.")

app = FastAPI(
    openapi_url="/api/v1/sparrow-ml/openapi.json",
    docs_url="/api/v1/sparrow-ml/docs",
)

# NOTE(review): wildcard origins combined with allow_credentials=True may let
# any site make credentialed cross-origin requests — consider an explicit
# origin list if this API relies on cookies or auth headers.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
    allow_credentials=True,
)

app.include_router(inference.router, prefix="/api-inference/v1/sparrow-ml", tags=["Inference"])
app.include_router(training.router, prefix="/api-training/v1/sparrow-ml", tags=["Training"])

# Router for the endpoints declared with @router in this module.
# NOTE(review): in the visible code this router is never passed to
# app.include_router(). FastAPI copies a router's routes at inclusion time, so
# the @router endpoints may be unreachable — confirm, and include it after the
# endpoint definitions if they are meant to be served.
router = APIRouter()
def count_values(obj):
    """Count the leaf values in a nested structure of dicts and lists.

    A dict contributes the leaf counts of its values and a list contributes
    the leaf counts of its items. Any other object (str, int, None, tuple, ...)
    counts as exactly one leaf. Empty dicts and lists therefore count as zero.
    """
    if isinstance(obj, dict):
        children = obj.values()
    elif isinstance(obj, list):
        children = obj
    else:
        # Anything that is not a dict or list is a single leaf.
        return 1
    return sum(map(count_values, children))
# Seconds to wait when downloading an image from a client-supplied URL.
_IMAGE_URL_TIMEOUT_SECONDS = 30


@router.post("/inference")
async def run_inference(
    file: Optional[UploadFile] = File(None),
    image_url: Optional[str] = Form(None),
    shipper_id: int = Form(...),
    model_in_use: str = Form('donut'),
):
    """Run Donut document inference on an uploaded JPG or an image URL.

    The model URL is selected per shipper via ``settings.get_model_url``.
    After processing, the processing time, the number of leaf values in the
    result, the source file name and the model name are passed to
    ``utils.log_stats``.

    Args:
        file: Uploaded image. Only ``image/jpeg`` / ``image/jpg`` are accepted.
            It takes precedence over ``image_url`` when both are given.
        image_url: http(s) URL of an image to download when no file is sent.
        shipper_id: Identifier used to select the model URL.
        model_in_use: Accepted for API compatibility; not read by this handler.

    Returns:
        The result of ``process_document_donut``, an ``{"error": ...}`` dict
        for rejected input, or ``{"info": "No input provided"}``.
    """
    from urllib.parse import urlparse

    # Dynamically select the model for this shipper.
    model_url = settings.get_model_url(shipper_id)
    model_name = model_url.replace("https://huggingface.co/spaces/", "")
    print(f"[DEBUG] Using model: {model_name}")

    if file:
        if file.content_type not in ["image/jpeg", "image/jpg"]:
            return {"error": "Invalid file type. Only JPG images are allowed."}
        image = Image.open(BytesIO(await file.read()))
        result, processing_time = process_document_donut(image, model_url)
        utils.log_stats(settings.inference_stats_file, [processing_time, count_values(result), file.filename, model_name])
        print(f"Processing time: {processing_time:.2f} seconds")
    elif image_url:
        # urlopen() also accepts file://, ftp:// and data: URLs. Restricting the
        # scheme stops clients from making the server read its own local files.
        # NOTE(review): http(s) URLs can still target internal hosts (SSRF);
        # consider a host allow-list if that matters for this deployment.
        if urlparse(image_url).scheme not in ("http", "https"):
            return {"error": "Invalid image URL. Only http and https URLs are allowed."}
        # NOTE(review): urlopen and process_document_donut are blocking calls
        # inside an async handler and will stall the event loop while they run.
        with urllib.request.urlopen(image_url, timeout=_IMAGE_URL_TIMEOUT_SECONDS) as response:
            image = Image.open(BytesIO(response.read()))
        result, processing_time = process_document_donut(image, model_url)
        file_name = image_url.split("/")[-1]
        utils.log_stats(settings.inference_stats_file, [processing_time, count_values(result), file_name, model_name])
        print(f"Processing time inference: {processing_time:.2f} seconds")
    else:
        result = {"info": "No input provided"}
    return result
@router.get("/statistics")
async def get_statistics():
    """Return the inference statistics stored in ``settings.inference_stats_file``.

    Returns:
        The JSON-decoded file contents, or ``[]`` when the file does not exist
        or cannot be decoded as UTF-8 JSON.
    """
    file_path = settings.inference_stats_file
    # EAFP: open directly instead of checking os.path.exists() first, which
    # avoids a 500 if the file disappears between the check and the open.
    # Assumes utils.log_stats writes UTF-8 (or ASCII) JSON — confirm.
    try:
        with open(file_path, "r", encoding="utf-8") as stats_file:
            return json.load(stats_file)
    except (FileNotFoundError, json.JSONDecodeError, UnicodeDecodeError):
        return []
@app.get("/")
async def root():
    """Root endpoint; returns a fixed message identifying the service."""
    return {"message": "Senga delivery notes inferencing"}