serenarolloh commited on
Commit
4382401
·
verified ·
1 Parent(s): 2468ba3

Update endpoints.py

Browse files
Files changed (1) hide show
  1. endpoints.py +71 -20
endpoints.py CHANGED
@@ -1,26 +1,77 @@
1
- from fastapi import FastAPI
2
- from fastapi.middleware.cors import CORSMiddleware
3
- from routers import inference, training
4
- from huggingface_hub import login
5
- from config import settings
 
6
  import os
7
- # login(settings.huggingface_key)
8
- login(os.getenv("HUGGINGFACE_KEY"))
 
9
 
10
- app = FastAPI(openapi_url="/api/v1/sparrow-ml/openapi.json", docs_url="/api/v1/sparrow-ml/docs")
11
 
12
- app.add_middleware(
13
- CORSMiddleware,
14
- allow_origins=["*"],
15
- allow_methods=["*"],
16
- allow_headers=["*"],
17
- allow_credentials=True,
18
- )
 
 
 
 
 
 
19
 
20
- app.include_router(inference.router, prefix="/api-inference/v1/sparrow-ml", tags=["Inference"])
21
- app.include_router(training.router, prefix="/api-training/v1/sparrow-ml", tags=["Training"])
 
22
 
 
 
 
 
 
 
 
23
 
24
- @app.get("/")
25
- async def root():
26
- return {"message": "Senga Dnotes Inferencing"}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastapi import FastAPI, APIRouter, File, UploadFile, Form
2
+ from typing import Optional
3
+ from PIL import Image
4
+ import urllib.request
5
+ from io import BytesIO
6
+ import json
7
  import os
8
+ from config import settings
9
+ import utils
10
+ from routers.donut_inference import process_document_donut
11
 
12
+ router = APIRouter()
13
 
14
+ def count_values(obj):
15
+ if isinstance(obj, dict):
16
+ count = 0
17
+ for value in obj.values():
18
+ count += count_values(value)
19
+ return count
20
+ elif isinstance(obj, list):
21
+ count = 0
22
+ for item in obj:
23
+ count += count_values(item)
24
+ return count
25
+ else:
26
+ return 1
27
 
28
+ @router.post("/inference")
29
+ async def run_inference(file: Optional[UploadFile] = File(None), image_url: Optional[str] = Form(None),
30
+ shipper_id: int = Form(...), model_in_use: str = Form('donut')):
31
 
32
+ result = []
33
+ model_url = settings.get_model_url(shipper_id) # Get the correct model URL based on shipper_id
34
+
35
+ if file:
36
+ # Ensure the uploaded file is a JPG image
37
+ if file.content_type not in ["image/jpeg", "image/jpg"]:
38
+ return {"error": "Invalid file type. Only JPG images are allowed."}
39
 
40
+ image = Image.open(BytesIO(await file.read()))
41
+ processing_time = 0
42
+ if model_in_use == 'donut':
43
+ result, processing_time = process_document_donut(image, model_url) # Pass model_url to the function
44
+ utils.log_stats(settings.inference_stats_file, [processing_time, count_values(result), file.filename, settings.model])
45
+ print(f"Processing time: {processing_time:.2f} seconds")
46
+ elif image_url:
47
+ # test image url: https://raw.githubusercontent.com/katanaml/sparrow/main/sparrow-data/docs/input/invoices/processed/images/invoice_10.jpg
48
+ with urllib.request.urlopen(image_url) as url:
49
+ image = Image.open(BytesIO(url.read()))
50
+
51
+ processing_time = 0
52
+ if model_in_use == 'donut':
53
+ result, processing_time = process_document_donut(image, model_url) # Pass model_url to the function
54
+ # parse file name from url
55
+ file_name = image_url.split("/")[-1]
56
+ utils.log_stats(settings.inference_stats_file, [processing_time, count_values(result), file_name, settings.model])
57
+ print(f"Processing time inference: {processing_time:.2f} seconds")
58
+ else:
59
+ result = {"info": "No input provided"}
60
+
61
+ return result
62
+
63
+ @router.get("/statistics")
64
+ async def get_statistics():
65
+ file_path = settings.inference_stats_file
66
+
67
+ # Check if the file exists, and read its content
68
+ if os.path.exists(file_path):
69
+ with open(file_path, 'r') as file:
70
+ try:
71
+ content = json.load(file)
72
+ except json.JSONDecodeError:
73
+ content = []
74
+ else:
75
+ content = []
76
+
77
+ return content