File size: 3,472 Bytes
4382401
cbbd0e6
4382401
 
 
 
 
 
 
8cf136d
4382401
c05980d
 
 
 
c2d58b3
8cf136d
 
a63a07c
 
 
 
 
 
 
 
8cf136d
 
 
4382401
c2d58b3
4382401
 
 
 
 
 
 
 
 
 
 
 
 
c2d58b3
4382401
 
 
c2d58b3
4382401
 
 
 
 
 
 
c2d58b3
4382401
 
 
 
 
 
 
 
 
 
 
 
 
c05980d
4382401
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a63a07c
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
from fastapi import FastAPI, APIRouter, File, UploadFile, Form
from fastapi.middleware.cors import CORSMiddleware
from typing import Optional
from PIL import Image
import urllib.request
from io import BytesIO
import json
from config import settings  
import utils
from routers import inference, training
from routers.donut_inference import process_document_donut
from huggingface_hub import login
import os
# Authenticate with the Hugging Face Hub only when a token is configured.
# huggingface_hub.login() without a token falls back to an interactive prompt,
# which would block (or fail) a non-interactive server process at import time.
_huggingface_key = os.getenv("HUGGINGFACE_KEY")
if _huggingface_key:
    login(_huggingface_key)
else:
    print("HUGGINGFACE_KEY is not set; skipping Hugging Face Hub login")

# OpenAPI schema and Swagger UI are served under the service's versioned prefix.
app = FastAPI(openapi_url="/api/v1/sparrow-ml/openapi.json", docs_url="/api/v1/sparrow-ml/docs")

# NOTE(review): wildcard origins combined with allow_credentials=True permits
# credentialed cross-origin requests from any site — confirm this is intended.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
    allow_credentials=True,
)

# Mount the inference and training routers, each under its own API prefix.
app.include_router(inference.router, prefix="/api-inference/v1/sparrow-ml", tags=["Inference"])
app.include_router(training.router, prefix="/api-training/v1/sparrow-ml", tags=["Training"])

# Local router for the endpoints defined below in this module.
# NOTE(review): no app.include_router(router) call is visible in this file, so
# the routes registered on it may not be reachable — confirm whether they
# should be mounted (after their definitions) or removed.
router = APIRouter()

def count_values(obj):
    """Count the leaf values in a nested structure of dicts and lists.

    Dicts contribute the counts of their values, lists the counts of their
    items; anything else (including strings and tuples) counts as one leaf.
    Empty dicts and lists contribute zero.
    """
    if isinstance(obj, dict):
        children = obj.values()
    elif isinstance(obj, list):
        children = obj
    else:
        # A non-container is a single leaf value.
        return 1
    return sum(count_values(child) for child in children)

# Upper bound on how long fetching a remote image may block. urlopen has no
# timeout by default and would otherwise hang the request indefinitely.
_IMAGE_URL_TIMEOUT_SECONDS = 30


def _run_model_and_log(image, model_url, model_in_use, file_name):
    """Run the selected model on ``image`` and append a row to the stats file.

    Only ``model_in_use == 'donut'`` dispatches to a model; for any other value
    no model runs and ``([], 0)`` is returned (a stats row is still logged).

    Returns:
        A ``(result, processing_time)`` tuple.
    """
    result = []
    processing_time = 0
    if model_in_use == 'donut':
        result, processing_time = process_document_donut(image, model_url)
    utils.log_stats(settings.inference_stats_file,
                    [processing_time, count_values(result), file_name, settings.model])
    return result, processing_time


@router.post("/inference")
async def run_inference(file: Optional[UploadFile] = File(None), image_url: Optional[str] = Form(None),
                        shipper_id: int = Form(...), model_in_use: str = Form('donut')):
    """Run document inference on an uploaded JPG or on an image fetched by URL.

    Args:
        file: Uploaded image; must have content type image/jpeg or image/jpg.
            Takes precedence over ``image_url`` when both are given.
        image_url: http(s) URL of an image to download and process.
        shipper_id: Selects the model URL via ``settings.get_model_url``.
        model_in_use: Model selector; only ``'donut'`` runs a model.

    Returns:
        The model result, a dict with an ``"error"`` key for rejected input,
        or ``{"info": "No input provided"}`` when neither input is given.
    """
    model_url = settings.get_model_url(shipper_id)  # Model URL depends on the shipper

    if file:
        if file.content_type not in ["image/jpeg", "image/jpg"]:
            return {"error": "Invalid file type. Only JPG images are allowed."}

        image = Image.open(BytesIO(await file.read()))
        result, processing_time = _run_model_and_log(image, model_url, model_in_use, file.filename)
        print(f"Processing time: {processing_time:.2f} seconds")
        return result

    if image_url:
        from urllib.parse import urlparse

        # image_url is client-controlled: urlopen would also honour file://,
        # ftp:// etc., allowing local file reads / SSRF. Only allow web URLs.
        if urlparse(image_url).scheme not in ("http", "https"):
            return {"error": "Invalid image URL. Only http and https URLs are allowed."}

        # NOTE(review): urlopen and the model call are synchronous and block the
        # event loop while they run — consider offloading to a thread pool.
        with urllib.request.urlopen(image_url, timeout=_IMAGE_URL_TIMEOUT_SECONDS) as response:
            image = Image.open(BytesIO(response.read()))

        file_name = image_url.split("/")[-1]
        result, processing_time = _run_model_and_log(image, model_url, model_in_use, file_name)
        print(f"Processing time inference: {processing_time:.2f} seconds")
        return result

    return {"info": "No input provided"}

@router.get("/statistics")
async def get_statistics():
    """Return the parsed contents of the inference statistics file.

    Reads ``settings.inference_stats_file`` as JSON. A missing file, or one
    that does not contain valid JSON, yields an empty list instead of an error.
    """
    file_path = settings.inference_stats_file

    # EAFP: opening directly avoids the race where the file disappears between
    # an existence check and the open() call.
    try:
        with open(file_path, 'r') as file:
            return json.load(file)
    except (FileNotFoundError, json.JSONDecodeError):
        return []

@app.get("/")
async def root():
    """Landing endpoint returning a fixed service-identification message."""
    payload = {"message": "Naivas LPO inferencing"}
    return payload