serenarolloh commited on
Commit
16be0ef
·
verified ·
1 Parent(s): 1a23ce8

Update routers/donut_inference.py

Browse files
Files changed (1) hide show
  1. routers/donut_inference.py +6 -11
routers/donut_inference.py CHANGED
@@ -7,16 +7,11 @@ from functools import lru_cache
7
  import os
8
  import requests
9
 
 
10
  @lru_cache(maxsize=1)
11
  def load_model(model_url: str):
12
- """
13
- Load the processor and model dynamically based on the model URL.
14
-
15
- :param model_url: The URL for the model to use.
16
- :return: The processor, model, and device.
17
- """
18
- # Assuming the model URL follows a pattern like "https://huggingface.co/{model_name}"
19
- model_name = model_url.split("/")[-1] # Extract model name from the URL
20
 
21
  processor = DonutProcessor.from_pretrained(model_name)
22
  model = VisionEncoderDecoderModel.from_pretrained(model_name)
@@ -36,7 +31,7 @@ def process_document_donut(image, model_url: str):
36
  :return: A tuple of the result and processing time.
37
  """
38
  worker_pid = os.getpid()
39
- print(f"Handling inference request with worker PID: {worker_pid}")
40
 
41
  start_time = time.time()
42
 
@@ -72,6 +67,6 @@ def process_document_donut(image, model_url: str):
72
  end_time = time.time()
73
  processing_time = end_time - start_time
74
 
75
- print(f"Inference done, worker PID: {worker_pid}")
76
 
77
- return processor.token2json(sequence), processing_time
 
7
  import os
8
  import requests
9
 
10
+
11
  @lru_cache(maxsize=1)
12
  def load_model(model_url: str):
13
+ model_name = model_url.replace("https://huggingface.co/", "")
14
+ print(f"[Model Loader] Loading model: {model_name}")
 
 
 
 
 
 
15
 
16
  processor = DonutProcessor.from_pretrained(model_name)
17
  model = VisionEncoderDecoderModel.from_pretrained(model_name)
 
31
  :return: A tuple of the result and processing time.
32
  """
33
  worker_pid = os.getpid()
34
+ print(f"[Inference] Handling request with worker PID: {worker_pid}")
35
 
36
  start_time = time.time()
37
 
 
67
  end_time = time.time()
68
  processing_time = end_time - start_time
69
 
70
+ print(f"[Inference] Done. PID: {worker_pid} | Time taken: {processing_time:.2f} sec")
71
 
72
+ return processor.token2json(sequence), processing_time