serenarolloh committed on
Commit
969487e
·
verified ·
1 Parent(s): 73918cc

Update endpoints.py

Browse files
Files changed (1) hide show
  1. endpoints.py +24 -29
endpoints.py CHANGED
@@ -11,7 +11,8 @@ from routers import inference, training
11
  from routers.donut_inference import process_document_donut
12
  from huggingface_hub import login
13
  import os
14
- # login(settings.huggingface_key)
 
15
  login(os.getenv("HUGGINGFACE_KEY"))
16
 
17
  app = FastAPI(openapi_url="/api/v1/sparrow-ml/openapi.json", docs_url="/api/v1/sparrow-ml/docs")
@@ -31,47 +32,44 @@ router = APIRouter()
31
 
32
  def count_values(obj):
33
  if isinstance(obj, dict):
34
- count = 0
35
- for value in obj.values():
36
- count += count_values(value)
37
- return count
38
  elif isinstance(obj, list):
39
- count = 0
40
- for item in obj:
41
- count += count_values(item)
42
- return count
43
  else:
44
  return 1
45
 
46
  @router.post("/inference")
47
- async def run_inference(file: Optional[UploadFile] = File(None), image_url: Optional[str] = Form(None),
48
- shipper_id: int = Form(...), model_in_use: str = Form('donut')):
49
-
 
 
 
50
  result = []
51
- model_url = settings.get_model_url(shipper_id) # Get the correct model URL based on shipper_id
52
-
 
 
 
 
53
  if file:
54
- # Ensure the uploaded file is a JPG image
55
  if file.content_type not in ["image/jpeg", "image/jpg"]:
56
  return {"error": "Invalid file type. Only JPG images are allowed."}
57
-
58
  image = Image.open(BytesIO(await file.read()))
59
- processing_time = 0
60
- if model_in_use == 'donut':
61
- result, processing_time = process_document_donut(image, model_url) # Pass model_url to the function
62
- utils.log_stats(settings.inference_stats_file, [processing_time, count_values(result), file.filename, settings.model])
63
  print(f"Processing time: {processing_time:.2f} seconds")
 
64
  elif image_url:
65
- # test image url: https://raw.githubusercontent.com/katanaml/sparrow/main/sparrow-data/docs/input/invoices/processed/images/invoice_10.jpg
66
  with urllib.request.urlopen(image_url) as url:
67
  image = Image.open(BytesIO(url.read()))
68
-
69
- processing_time = 0
70
- if model_in_use == 'donut':
71
- result, processing_time = process_document_donut(image, model_url)
72
  file_name = image_url.split("/")[-1]
73
- utils.log_stats(settings.inference_stats_file, [processing_time, count_values(result), file_name, settings.model])
74
  print(f"Processing time inference: {processing_time:.2f} seconds")
 
75
  else:
76
  result = {"info": "No input provided"}
77
 
@@ -80,8 +78,6 @@ async def run_inference(file: Optional[UploadFile] = File(None), image_url: Opti
80
  @router.get("/statistics")
81
  async def get_statistics():
82
  file_path = settings.inference_stats_file
83
-
84
- # Check if the file exists, and read its content
85
  if os.path.exists(file_path):
86
  with open(file_path, 'r') as file:
87
  try:
@@ -90,7 +86,6 @@ async def get_statistics():
90
  content = []
91
  else:
92
  content = []
93
-
94
  return content
95
 
96
  @app.get("/")
 
11
  from routers.donut_inference import process_document_donut
12
  from huggingface_hub import login
13
  import os
14
+
15
+ # Login using Hugging Face token from environment
16
  login(os.getenv("HUGGINGFACE_KEY"))
17
 
18
  app = FastAPI(openapi_url="/api/v1/sparrow-ml/openapi.json", docs_url="/api/v1/sparrow-ml/docs")
 
32
 
33
  def count_values(obj):
34
  if isinstance(obj, dict):
35
+ return sum(count_values(v) for v in obj.values())
 
 
 
36
  elif isinstance(obj, list):
37
+ return sum(count_values(i) for i in obj)
 
 
 
38
  else:
39
  return 1
40
 
41
  @router.post("/inference")
42
+ async def run_inference(
43
+ file: Optional[UploadFile] = File(None),
44
+ image_url: Optional[str] = Form(None),
45
+ shipper_id: int = Form(...),
46
+ model_in_use: str = Form('donut')
47
+ ):
48
  result = []
49
+
50
+ # Dynamically select model
51
+ model_url = settings.get_model_url(shipper_id)
52
+ model_name = model_url.replace("https://huggingface.co/spaces/", "")
53
+ print(f"[DEBUG] Using model: {model_name}")
54
+
55
  if file:
 
56
  if file.content_type not in ["image/jpeg", "image/jpg"]:
57
  return {"error": "Invalid file type. Only JPG images are allowed."}
58
+
59
  image = Image.open(BytesIO(await file.read()))
60
+ result, processing_time = process_document_donut(image, model_url)
61
+ utils.log_stats(settings.inference_stats_file, [processing_time, count_values(result), file.filename, model_name])
 
 
62
  print(f"Processing time: {processing_time:.2f} seconds")
63
+
64
  elif image_url:
 
65
  with urllib.request.urlopen(image_url) as url:
66
  image = Image.open(BytesIO(url.read()))
67
+
68
+ result, processing_time = process_document_donut(image, model_url)
 
 
69
  file_name = image_url.split("/")[-1]
70
+ utils.log_stats(settings.inference_stats_file, [processing_time, count_values(result), file_name, model_name])
71
  print(f"Processing time inference: {processing_time:.2f} seconds")
72
+
73
  else:
74
  result = {"info": "No input provided"}
75
 
 
78
  @router.get("/statistics")
79
  async def get_statistics():
80
  file_path = settings.inference_stats_file
 
 
81
  if os.path.exists(file_path):
82
  with open(file_path, 'r') as file:
83
  try:
 
86
  content = []
87
  else:
88
  content = []
 
89
  return content
90
 
91
  @app.get("/")