serenarolloh committed
Commit 30875d3 · verified · 1 Parent(s): 58a60b8

Update routers/inference.py

Files changed (1)
  1. routers/inference.py +63 -35
routers/inference.py CHANGED
@@ -1,4 +1,4 @@
- from fastapi import APIRouter, File, UploadFile, Form
+ from fastapi import APIRouter, File, UploadFile, Form, HTTPException
  from typing import Optional
  from PIL import Image
  import urllib.request
@@ -15,7 +15,6 @@ import io
  logging.basicConfig(level=logging.INFO)
  logger = logging.getLogger(__name__)

-
  router = APIRouter()

  def count_values(obj):
@@ -32,47 +31,76 @@ def count_values(obj):
      else:
          return 1

-
  @router.post("/inference")
- async def run_inference(file: Optional[UploadFile] = File(None), image_url: Optional[str] = Form(None),
-                         model_in_use: str = Form('donut'), shipper_id: Optional[int] = Form(None)
+ async def run_inference(
+     file: Optional[UploadFile] = File(None),
+     image_url: Optional[str] = Form(None),
+     model_in_use: str = Form('donut'),
+     shipper_id: Optional[int] = Form(None)
  ):
-
+     # Validate input
+     if not file and not image_url:
+         return {"info": "No input provided"}
+
+     # Log the shipper_id that was received
+     logger.info(f"Received inference request with shipper_id: {shipper_id}")
+
+     # Convert shipper_id to string if provided (config.py expects a string)
+     shipper_id_str = str(shipper_id) if shipper_id is not None else "default_shipper"
+     logger.info(f"Using shipper_id: {shipper_id_str} for model selection")

      result = []
-     if file:
-         # Ensure the uploaded file is a JPG image
-         if file.content_type not in ["image/jpeg", "image/jpg"]:
-             return {"error": "Invalid file type. Only JPG images are allowed."}
-
-         image = Image.open(BytesIO(await file.read()))
-         processing_time = 0
-         if model_in_use == 'donut':
-             result, processing_time = process_document_donut(image, shipper_id)
-         utils.log_stats(settings.inference_stats_file, [processing_time, count_values(result), file.filename, settings.model])
-         print(f"Processing time: {processing_time:.2f} seconds")
-     elif image_url:
-         # test image url: https://raw.githubusercontent.com/katanaml/sparrow/main/sparrow-data/docs/input/invoices/processed/images/invoice_10.jpg
-         with urllib.request.urlopen(image_url) as url:
-             image = Image.open(BytesIO(url.read()))
-
-         processing_time = 0
-         if model_in_use == 'donut':
-             result, processing_time = process_document_donut(image, shipper_id)
-         # parse file name from url
-         file_name = image_url.split("/")[-1]
-         utils.log_stats(settings.inference_stats_file, [processing_time, count_values(result), file_name, settings.model])
-         print(f"Processing time inference: {processing_time:.2f} seconds")
-     else:
-         result = {"info": "No input provided"}
-
+     processing_time = 0
+
+     try:
+         if file:
+             # Ensure the uploaded file is a JPG image
+             if file.content_type not in ["image/jpeg", "image/jpg"]:
+                 logger.warning(f"Invalid file type: {file.content_type}")
+                 return {"error": "Invalid file type. Only JPG images are allowed."}
+
+             logger.info(f"Processing file: {file.filename}")
+             image = Image.open(BytesIO(await file.read()))
+
+             if model_in_use == 'donut':
+                 # Pass the shipper_id to the processing function
+                 result, processing_time = process_document_donut(image, shipper_id_str)
+                 utils.log_stats(settings.inference_stats_file, [processing_time, count_values(result), file.filename, settings.model])
+                 logger.info(f"Processing time: {processing_time:.2f} seconds with model: {settings.model}")
+             else:
+                 logger.warning(f"Unsupported model: {model_in_use}")
+                 return {"error": f"Unsupported model: {model_in_use}"}
+
+         elif image_url:
+             logger.info(f"Processing image from URL: {image_url}")
+             # test image url: https://raw.githubusercontent.com/katanaml/sparrow/main/sparrow-data/docs/input/invoices/processed/images/invoice_10.jpg
+             try:
+                 with urllib.request.urlopen(image_url) as url:
+                     image = Image.open(BytesIO(url.read()))
+
+                 if model_in_use == 'donut':
+                     # Pass the shipper_id to the processing function
+                     result, processing_time = process_document_donut(image, shipper_id_str)
+                     # parse file name from url
+                     file_name = image_url.split("/")[-1]
+                     utils.log_stats(settings.inference_stats_file, [processing_time, count_values(result), file_name, settings.model])
+                     logger.info(f"Processing time inference: {processing_time:.2f} seconds with model: {settings.model}")
+                 else:
+                     logger.warning(f"Unsupported model: {model_in_use}")
+                     return {"error": f"Unsupported model: {model_in_use}"}
+             except Exception as e:
+                 logger.error(f"Error processing image URL: {str(e)}")
+                 return {"error": f"Error processing image URL: {str(e)}"}
+
+     except Exception as e:
+         logger.error(f"Error during inference: {str(e)}")
+         return {"error": f"Inference failed: {str(e)}"}
+
      return result

-
  @router.get("/statistics")
  async def get_statistics():
      file_path = settings.inference_stats_file
-
      # Check if the file exists, and read its content
      if os.path.exists(file_path):
          with open(file_path, 'r') as file:
@@ -82,5 +110,5 @@ async def get_statistics():
                  content = []
      else:
          content = []
-
+
      return content
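
For reference, a minimal client-side sketch of how the updated endpoint could be exercised. It assumes the router is mounted at the application root and that the service listens on http://localhost:8000; the shipper_id value and file name are purely illustrative, not part of this commit. When shipper_id is omitted, the handler now falls back to "default_shipper".

# Hypothetical client calls against the updated /inference endpoint.
# Base URL, mount path, and shipper_id are assumptions for illustration only.
import requests

BASE_URL = "http://localhost:8000"

# Option 1: upload a local JPG (sent as multipart form data)
with open("invoice_10.jpg", "rb") as f:
    resp = requests.post(
        f"{BASE_URL}/inference",
        files={"file": ("invoice_10.jpg", f, "image/jpeg")},
        data={"model_in_use": "donut", "shipper_id": 123},  # shipper_id is optional
    )
print(resp.json())

# Option 2: let the server fetch the image from a URL instead of uploading
resp = requests.post(
    f"{BASE_URL}/inference",
    data={
        "image_url": "https://raw.githubusercontent.com/katanaml/sparrow/main/sparrow-data/docs/input/invoices/processed/images/invoice_10.jpg",
        "model_in_use": "donut",
    },
)
print(resp.json())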