serenarolloh committed on
Commit
173bbcf
·
verified ·
1 Parent(s): 7729ce9

Update routers/inference.py

Browse files
Files changed (1) hide show
  1. routers/inference.py +18 -26
routers/inference.py CHANGED
@@ -3,67 +3,59 @@ from typing import Optional
3
  from PIL import Image
4
  import urllib.request
5
  from io import BytesIO
6
- from config import settings
7
  import utils
8
  import os
9
  import json
 
10
  from routers.donut_inference import process_document_donut
11
 
12
-
13
  router = APIRouter()
14
 
15
  def count_values(obj):
16
  if isinstance(obj, dict):
17
- count = 0
18
- for value in obj.values():
19
- count += count_values(value)
20
- return count
21
  elif isinstance(obj, list):
22
- count = 0
23
- for item in obj:
24
- count += count_values(item)
25
- return count
26
  else:
27
  return 1
28
 
29
-
30
  @router.post("/inference")
31
  async def run_inference(
32
  file: Optional[UploadFile] = File(None),
33
  image_url: Optional[str] = Form(None),
34
- shipper_id: str = Form(...),
35
- model_in_use: Optional[str] = Form(None)
36
  ):
37
-
 
 
38
  result = []
 
 
39
  if file:
40
- # Ensure the uploaded file is a JPG image
41
  if file.content_type not in ["image/jpeg", "image/jpg"]:
42
  return {"error": "Invalid file type. Only JPG images are allowed."}
43
 
44
  image = Image.open(BytesIO(await file.read()))
45
- processing_time = 0
46
  if model_in_use == 'donut':
47
- result, processing_time = process_document_donut(image)
48
  utils.log_stats(settings.inference_stats_file, [processing_time, count_values(result), file.filename, settings.model])
49
- print(f"Processing time: {processing_time:.2f} seconds")
50
  elif image_url:
51
- # test image url: https://raw.githubusercontent.com/katanaml/sparrow/main/sparrow-data/docs/input/invoices/processed/images/invoice_10.jpg
52
  with urllib.request.urlopen(image_url) as url:
53
  image = Image.open(BytesIO(url.read()))
54
-
55
- processing_time = 0
56
  if model_in_use == 'donut':
57
- result, processing_time = process_document_donut(image)
58
- # parse file name from url
59
  file_name = image_url.split("/")[-1]
60
  utils.log_stats(settings.inference_stats_file, [processing_time, count_values(result), file_name, settings.model])
61
- print(f"Processing time inference: {processing_time:.2f} seconds")
62
  else:
63
  result = {"info": "No input provided"}
64
 
65
- return result
66
-
 
 
 
 
67
 
68
  @router.get("/statistics")
69
  async def get_statistics():
 
3
  from PIL import Image
4
  import urllib.request
5
  from io import BytesIO
 
6
  import utils
7
  import os
8
  import json
9
+ from config import Settings
10
  from routers.donut_inference import process_document_donut
11
 
 
12
  router = APIRouter()
13
 
14
  def count_values(obj):
15
  if isinstance(obj, dict):
16
+ return sum(count_values(v) for v in obj.values())
 
 
 
17
  elif isinstance(obj, list):
18
+ return sum(count_values(i) for i in obj)
 
 
 
19
  else:
20
  return 1
21
 
 
22
  @router.post("/inference")
23
  async def run_inference(
24
  file: Optional[UploadFile] = File(None),
25
  image_url: Optional[str] = Form(None),
26
+ model_in_use: str = Form('donut'),
27
+ shipper_id: str = Form(...)
28
  ):
29
+ # Dynamically load config based on shipper ID
30
+ settings = Settings(shipper_id=shipper_id)
31
+
32
  result = []
33
+ processing_time = 0
34
+
35
  if file:
 
36
  if file.content_type not in ["image/jpeg", "image/jpg"]:
37
  return {"error": "Invalid file type. Only JPG images are allowed."}
38
 
39
  image = Image.open(BytesIO(await file.read()))
 
40
  if model_in_use == 'donut':
41
+ result, processing_time = process_document_donut(image, settings)
42
  utils.log_stats(settings.inference_stats_file, [processing_time, count_values(result), file.filename, settings.model])
 
43
  elif image_url:
 
44
  with urllib.request.urlopen(image_url) as url:
45
  image = Image.open(BytesIO(url.read()))
 
 
46
  if model_in_use == 'donut':
47
+ result, processing_time = process_document_donut(image, settings)
 
48
  file_name = image_url.split("/")[-1]
49
  utils.log_stats(settings.inference_stats_file, [processing_time, count_values(result), file_name, settings.model])
 
50
  else:
51
  result = {"info": "No input provided"}
52
 
53
+ return {
54
+ "shipper_id": shipper_id,
55
+ "model": settings.model,
56
+ "processor": settings.processor,
57
+ "result": result
58
+ }
59
 
60
  @router.get("/statistics")
61
  async def get_statistics():