Update routers/inference.py
routers/inference.py  CHANGED  (+63 -35)
@@ -1,4 +1,4 @@
-from fastapi import APIRouter, File, UploadFile, Form
+from fastapi import APIRouter, File, UploadFile, Form, HTTPException
 from typing import Optional
 from PIL import Image
 import urllib.request
@@ -15,7 +15,6 @@ import io
 logging.basicConfig(level=logging.INFO)
 logger = logging.getLogger(__name__)

-
 router = APIRouter()

 def count_values(obj):
@@ -32,47 +31,76 @@ def count_values(obj):
     else:
         return 1

-
 @router.post("/inference")
-async def run_inference(
-
+async def run_inference(
+    file: Optional[UploadFile] = File(None),
+    image_url: Optional[str] = Form(None),
+    model_in_use: str = Form('donut'),
+    shipper_id: Optional[int] = Form(None)
 ):
-
+    # Validate input
+    if not file and not image_url:
+        return {"info": "No input provided"}
+
+    # Log the shipper_id that was received
+    logger.info(f"Received inference request with shipper_id: {shipper_id}")
+
+    # Convert shipper_id to string if provided (config.py expects a string)
+    shipper_id_str = str(shipper_id) if shipper_id is not None else "default_shipper"
+    logger.info(f"Using shipper_id: {shipper_id_str} for model selection")

     result = []
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+    processing_time = 0
+
+    try:
+        if file:
+            # Ensure the uploaded file is a JPG image
+            if file.content_type not in ["image/jpeg", "image/jpg"]:
+                logger.warning(f"Invalid file type: {file.content_type}")
+                return {"error": "Invalid file type. Only JPG images are allowed."}
+
+            logger.info(f"Processing file: {file.filename}")
+            image = Image.open(BytesIO(await file.read()))
+
+            if model_in_use == 'donut':
+                # Pass the shipper_id to the processing function
+                result, processing_time = process_document_donut(image, shipper_id_str)
+                utils.log_stats(settings.inference_stats_file, [processing_time, count_values(result), file.filename, settings.model])
+                logger.info(f"Processing time: {processing_time:.2f} seconds with model: {settings.model}")
+            else:
+                logger.warning(f"Unsupported model: {model_in_use}")
+                return {"error": f"Unsupported model: {model_in_use}"}
+
+        elif image_url:
+            logger.info(f"Processing image from URL: {image_url}")
+            # test image url: https://raw.githubusercontent.com/katanaml/sparrow/main/sparrow-data/docs/input/invoices/processed/images/invoice_10.jpg
+            try:
+                with urllib.request.urlopen(image_url) as url:
+                    image = Image.open(BytesIO(url.read()))
+
+                if model_in_use == 'donut':
+                    # Pass the shipper_id to the processing function
+                    result, processing_time = process_document_donut(image, shipper_id_str)
+                    # parse file name from url
+                    file_name = image_url.split("/")[-1]
+                    utils.log_stats(settings.inference_stats_file, [processing_time, count_values(result), file_name, settings.model])
+                    logger.info(f"Processing time inference: {processing_time:.2f} seconds with model: {settings.model}")
+                else:
+                    logger.warning(f"Unsupported model: {model_in_use}")
+                    return {"error": f"Unsupported model: {model_in_use}"}
+            except Exception as e:
+                logger.error(f"Error processing image URL: {str(e)}")
+                return {"error": f"Error processing image URL: {str(e)}"}
+
+    except Exception as e:
+        logger.error(f"Error during inference: {str(e)}")
+        return {"error": f"Inference failed: {str(e)}"}
+
     return result

-
 @router.get("/statistics")
 async def get_statistics():
     file_path = settings.inference_stats_file
-
     # Check if the file exists, and read its content
     if os.path.exists(file_path):
         with open(file_path, 'r') as file:
@@ -82,5 +110,5 @@ async def get_statistics():
                 content = []
     else:
         content = []
-
+
     return content
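
For reference, a minimal client sketch for the updated POST /inference endpoint. It assumes the router is mounted without a prefix on a server at http://localhost:8000; BASE_URL, the local file name, and the shipper_id value are placeholders, while the form field names, the JPG-only restriction, and the test image URL come from the diff above.

import requests

BASE_URL = "http://localhost:8000"  # assumption: adjust host/port and any route prefix

# Option A: upload a local JPG (the endpoint rejects non-JPEG content types)
with open("invoice_10.jpg", "rb") as f:
    resp = requests.post(
        f"{BASE_URL}/inference",
        files={"file": ("invoice_10.jpg", f, "image/jpeg")},
        data={"model_in_use": "donut", "shipper_id": 42},
    )
print(resp.json())  # extracted fields, or {"error": ...} / {"info": ...}

# Option B: let the server fetch the image itself via the image_url form field
resp = requests.post(
    f"{BASE_URL}/inference",
    data={
        "image_url": "https://raw.githubusercontent.com/katanaml/sparrow/main/sparrow-data/docs/input/invoices/processed/images/invoice_10.jpg",
        "model_in_use": "donut",
        "shipper_id": 42,
    },
)
print(resp.json())

Because shipper_id is declared as Form(None), omitting it makes the endpoint fall back to "default_shipper" on the server side.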
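
A similar sketch for GET /statistics, assuming the same http://localhost:8000 base URL. The exact record layout depends on utils.log_stats and on the part of get_statistics not shown in this diff, but the log_stats call sites above pass a four-element list per entry.

import requests

BASE_URL = "http://localhost:8000"  # assumption: same server and prefix as the inference calls

stats = requests.get(f"{BASE_URL}/statistics").json()
# Each record is expected to mirror the list passed to utils.log_stats:
# [processing_time_seconds, number_of_extracted_values, file_name, model]
for record in stats:
    print(record)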