serenarolloh committed on
Commit
f1483b9
·
verified ·
1 Parent(s): 5d098a5

Update routers/inference.py

Browse files
Files changed (1) hide show
  1. routers/inference.py +27 -22
routers/inference.py CHANGED
@@ -3,59 +3,64 @@ from typing import Optional
3
  from PIL import Image
4
  import urllib.request
5
  from io import BytesIO
 
6
  import utils
7
  import os
8
  import json
9
- from config import Settings
10
  from routers.donut_inference import process_document_donut
11
 
 
12
  router = APIRouter()
13
 
def count_values(obj):
    """Count the leaf values contained in a nested dict/list structure.

    Dicts contribute the leaves found in their values (keys are not
    counted), lists contribute the leaves found in their items, and any
    other object counts as exactly one leaf. Empty dicts and lists
    therefore contribute 0.

    The traversal uses an explicit stack instead of recursion so that a
    very deeply nested structure cannot raise ``RecursionError``.

    Args:
        obj: Any object; only ``dict`` and ``list`` are descended into.

    Returns:
        The number of non-dict, non-list objects reachable from ``obj``.
    """
    total = 0
    pending = [obj]
    while pending:
        current = pending.pop()
        if isinstance(current, dict):
            pending.extend(current.values())
        elif isinstance(current, list):
            pending.extend(current)
        else:
            total += 1
    return total
21
 
 
@router.post("/inference")
async def run_inference(
    file: Optional[UploadFile] = File(None),
    image_url: Optional[str] = Form(None),
    model_in_use: str = Form('donut'),
    shipper_id: str = Form(...)
):
    """Run document inference on an uploaded JPG or on an image fetched by URL.

    Args:
        file: Optional uploaded image; only ``image/jpeg`` / ``image/jpg``
            content types are accepted.
        image_url: Optional URL of an image, used only when ``file`` is not
            provided. Must use the ``http`` or ``https`` scheme.
        model_in_use: Model selector; only ``'donut'`` triggers processing.
        shipper_id: Required identifier passed to ``Settings``.

    Returns:
        A dict with ``shipper_id``, ``model``, ``processor`` and ``result``
        keys, or an ``{"error": ...}`` dict when the input is rejected.
    """
    # Local import keeps this block self-contained; urllib.request is
    # already imported at module level.
    from urllib.parse import urlsplit

    # Dynamically load config based on shipper ID
    settings = Settings(shipper_id=shipper_id)

    result = []
    processing_time = 0

    if file:
        if file.content_type not in ["image/jpeg", "image/jpg"]:
            return {"error": "Invalid file type. Only JPG images are allowed."}

        image = Image.open(BytesIO(await file.read()))

        if model_in_use == 'donut':
            result, processing_time = process_document_donut(image, settings)
        # NOTE(review): the original indentation was lost; stats logging is
        # assumed to run regardless of the selected model - confirm.
        utils.log_stats(settings.inference_stats_file, [processing_time, count_values(result), file.filename, settings.model])

    elif image_url:
        # urlopen() also serves non-HTTP schemes such as file:, which would
        # let a caller make the server read local resources. Restrict the
        # caller-supplied URL to plain web schemes before fetching.
        if urlsplit(image_url).scheme not in ("http", "https"):
            return {"error": "Invalid image URL. Only http and https URLs are allowed."}

        with urllib.request.urlopen(image_url) as url:
            image = Image.open(BytesIO(url.read()))

        if model_in_use == 'donut':
            result, processing_time = process_document_donut(image, settings)

        # Last path segment of the URL is used as the logged file name.
        file_name = image_url.split("/")[-1]
        utils.log_stats(settings.inference_stats_file, [processing_time, count_values(result), file_name, settings.model])

    else:
        result = {"info": "No input provided"}

    return {
        "shipper_id": shipper_id,
        "model": settings.model,
        "processor": settings.processor,
        "result": result
    }
59
 
60
  @router.get("/statistics")
61
  async def get_statistics():
@@ -71,4 +76,4 @@ async def get_statistics():
71
  else:
72
  content = []
73
 
74
- return content
 
3
  from PIL import Image
4
  import urllib.request
5
  from io import BytesIO
6
+ from config import settings
7
  import utils
8
  import os
9
  import json
 
10
  from routers.donut_inference import process_document_donut
11
 
12
+
13
  router = APIRouter()
14
 
def count_values(obj):
    """Count the leaf values contained in a nested dict/list structure.

    Dicts contribute the leaves found in their values (keys are not
    counted), lists contribute the leaves found in their items, and any
    other object counts as exactly one leaf. Empty dicts and lists
    therefore contribute 0.

    A single iterative walk replaces the two duplicated accumulation
    loops and avoids ``RecursionError`` on very deeply nested input.

    Args:
        obj: Any object; only ``dict`` and ``list`` are descended into.

    Returns:
        The number of non-dict, non-list objects reachable from ``obj``.
    """
    count = 0
    stack = [obj]
    while stack:
        node = stack.pop()
        if isinstance(node, dict):
            stack.extend(node.values())
        elif isinstance(node, list):
            stack.extend(node)
        else:
            count += 1
    return count
28
 
29
+
@router.post("/inference")
async def run_inference(file: Optional[UploadFile] = File(None), image_url: Optional[str] = Form(None),
                        model_in_use: str = Form('donut')):
    """Run document inference on an uploaded JPG or on an image fetched by URL.

    Args:
        file: Optional uploaded image; only ``image/jpeg`` / ``image/jpg``
            content types are accepted.
        image_url: Optional URL of an image, used only when ``file`` is not
            provided. Must use the ``http`` or ``https`` scheme.
        model_in_use: Model selector; only ``'donut'`` triggers processing.

    Returns:
        Whatever ``process_document_donut`` produced as its result, an empty
        list when no model ran, ``{"info": ...}`` when no input was given,
        or ``{"error": ...}`` when the input is rejected.
    """
    # Local import keeps this block self-contained; urllib.request is
    # already imported at module level.
    from urllib.parse import urlsplit

    result = []
    if file:
        # Ensure the uploaded file is a JPG image
        if file.content_type not in ["image/jpeg", "image/jpg"]:
            return {"error": "Invalid file type. Only JPG images are allowed."}

        image = Image.open(BytesIO(await file.read()))
        processing_time = 0
        if model_in_use == 'donut':
            result, processing_time = process_document_donut(image)
        # NOTE(review): the original indentation was lost; stats logging is
        # assumed to run regardless of the selected model - confirm.
        utils.log_stats(settings.inference_stats_file, [processing_time, count_values(result), file.filename, settings.model])
        print(f"Processing time: {processing_time:.2f} seconds")
    elif image_url:
        # urlopen() also serves non-HTTP schemes such as file:, which would
        # let a caller make the server read local resources. Restrict the
        # caller-supplied URL to plain web schemes before fetching.
        if urlsplit(image_url).scheme not in ("http", "https"):
            return {"error": "Invalid image URL. Only http and https URLs are allowed."}

        # test image url: https://raw.githubusercontent.com/katanaml/sparrow/main/sparrow-data/docs/input/invoices/processed/images/invoice_10.jpg
        with urllib.request.urlopen(image_url) as url:
            image = Image.open(BytesIO(url.read()))

        processing_time = 0
        if model_in_use == 'donut':
            result, processing_time = process_document_donut(image)
        # parse file name from url
        file_name = image_url.split("/")[-1]
        utils.log_stats(settings.inference_stats_file, [processing_time, count_values(result), file_name, settings.model])
        print(f"Processing time inference: {processing_time:.2f} seconds")
    else:
        result = {"info": "No input provided"}

    return result
63
+
 
 
 
 
64
 
65
  @router.get("/statistics")
66
  async def get_statistics():
 
76
  else:
77
  content = []
78
 
79
+ return content