serenarolloh committed
Commit 0b1a8c8 · verified · 1 Parent(s): 841b35b

Update routers/donut_inference.py

Files changed (1)
  1. routers/donut_inference.py +74 -37
routers/donut_inference.py CHANGED
@@ -2,55 +2,93 @@ import re
 import time
 import torch
 from transformers import DonutProcessor, VisionEncoderDecoderModel
-from config import settings
+from config import settings, update_shipper
 from functools import lru_cache
 import os
+import logging
 
-@lru_cache(maxsize=10)
-def load_model(shipper_id: str):
-    temp_settings = Settings(shipper_id=shipper_id)
-    processor = DonutProcessor.from_pretrained(temp_settings.processor)
-    model = VisionEncoderDecoderModel.from_pretrained(temp_settings.model)
-
-    device = "cuda" if torch.cuda.is_available() else "cpu"
-    model.to(device)
-
-    return processor, model, device
+# Set up logging
+logging.basicConfig(level=logging.INFO)
+logger = logging.getLogger(__name__)
+
+# Model cache dictionary - maps shipper_id to loaded models
+model_cache = {}
+
+def load_model(shipper_id: str):
+    """
+    Load a model for a specific shipper_id, with caching
+
+    Args:
+        shipper_id: The shipper ID to load the model for
+
+    Returns:
+        tuple: (processor, model, device) for the specified shipper
+    """
+    # Check if this model is already loaded in cache
+    if shipper_id in model_cache:
+        logger.info(f"Using cached model for shipper {shipper_id}")
+        return model_cache[shipper_id]
+
+    # Update settings to use the appropriate model for this shipper
+    model_name, processor_name = update_shipper(shipper_id)
+
+    logger.info(f"Loading model for shipper {shipper_id}: model={model_name}, processor={processor_name}")
+
+    try:
+        # Load the model from HuggingFace
+        processor = DonutProcessor.from_pretrained(processor_name)
+        model = VisionEncoderDecoderModel.from_pretrained(model_name)
+        device = "cuda" if torch.cuda.is_available() else "cpu"
+        model.to(device)
+
+        # Cache the loaded model
+        model_cache[shipper_id] = (processor, model, device)
+
+        return processor, model, device
+    except Exception as e:
+        logger.error(f"Error loading model for shipper {shipper_id}: {str(e)}")
+        # Fall back to default model
+        if 'default_shipper' in model_cache:
+            logger.info("Falling back to default model")
+            return model_cache['default_shipper']
+        else:
+            logger.info("Loading default model")
+            processor = DonutProcessor.from_pretrained(settings.base_processor)
+            model = VisionEncoderDecoderModel.from_pretrained(settings.base_model)
+            device = "cuda" if torch.cuda.is_available() else "cpu"
+            model.to(device)
+            model_cache['default_shipper'] = (processor, model, device)
+            return model_cache['default_shipper']
 
 def process_document_donut(image, shipper_id="default_shipper"):
+    """
+    Process a document using the appropriate model for the shipper
+
+    Args:
+        image: The document image to process
+        shipper_id: Shipper ID to select a specific model
+
+    Returns:
+        tuple: (results, processing_time)
+    """
     worker_pid = os.getpid()
-    print(f"Handling inference request with worker PID: {worker_pid}")
-
+    logger.info(f"Handling inference request with worker PID: {worker_pid}, shipper_id: {shipper_id}")
+
     start_time = time.time()
-
+
+    # Load the model based on shipper_id
     processor, model, device = load_model(shipper_id)
-
+
     # prepare encoder inputs
     pixel_values = processor(image, return_tensors="pt").pixel_values
-
+
     # prepare decoder inputs
     task_prompt = "<s_cord-v2>"
     decoder_input_ids = processor.tokenizer(
-        task_prompt,
-        add_special_tokens=False,
+        task_prompt,
+        add_special_tokens=False,
         return_tensors="pt"
     ).input_ids
-
-    outputs = model.generate(
-        pixel_values.to(device),
-        decoder_input_ids=decoder_input_ids.to(device),
-        max_length=model.decoder.config.max_position_embeddings,
-        early_stopping=True,
-        pad_token_id=processor.tokenizer.pad_token_id,
-        eos_token_id=processor.tokenizer.eos_token_id,
-        use_cache=True,
-        num_beams=1,
-        bad_words_ids=[[processor.tokenizer.unk_token_id]],
-        return_dict_in_generate=True,
-    )
-
+
     # generate answer
     outputs = model.generate(
         pixel_values.to(device),
@@ -64,15 +102,14 @@ def process_document_donut(image, shipper_id="default_shipper"):
         bad_words_ids=[[processor.tokenizer.unk_token_id]],
         return_dict_in_generate=True,
     )
-
+
     # postprocess
     sequence = processor.batch_decode(outputs.sequences)[0]
     sequence = sequence.replace(processor.tokenizer.eos_token, "").replace(processor.tokenizer.pad_token, "")
     sequence = re.sub(r"<.*?>", "", sequence, count=1).strip()  # remove first task start token
-
+
     end_time = time.time()
     processing_time = end_time - start_time
-
-    print(f"Inference done, worker PID: {worker_pid}")
-
+    logger.info(f"Inference done in {processing_time:.2f}s, worker PID: {worker_pid}")
+
     return processor.token2json(sequence), processing_time
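For context, the updated module assumes `config.py` exposes an `update_shipper(shipper_id)` helper returning a `(model_name, processor_name)` pair, plus `settings.base_model` and `settings.base_processor` fallbacks. That module is not part of this commit, so the following is only a minimal sketch of the assumed interface; the checkpoint names and registry contents are illustrative, not the project's actual values.

```python
# config.py -- hypothetical sketch of the interface this commit assumes.

class Settings:
    # Fallback Donut checkpoints; CORD-v2 matches the "<s_cord-v2>" task prompt,
    # but the actual repo names used by this project are assumptions.
    base_model: str = "naver-clova-ix/donut-base-finetuned-cord-v2"
    base_processor: str = "naver-clova-ix/donut-base-finetuned-cord-v2"

settings = Settings()

# Assumed per-shipper registry: shipper_id -> (model repo, processor repo).
SHIPPER_MODELS = {
    "default_shipper": (settings.base_model, settings.base_processor),
}

def update_shipper(shipper_id: str) -> tuple[str, str]:
    """Resolve the (model_name, processor_name) pair for a shipper,
    falling back to the base checkpoints for unknown IDs."""
    return SHIPPER_MODELS.get(shipper_id, (settings.base_model, settings.base_processor))
```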
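One behavioral change worth noting: the old `@lru_cache(maxsize=10)` bounded how many models stayed resident, while the new plain-dict `model_cache` never evicts, so memory grows with the number of distinct shipper IDs. If a bound is needed, an `OrderedDict`-based variant is a small change (a sketch, not part of the commit):

```python
from collections import OrderedDict

MAX_CACHED_MODELS = 10  # assumed bound, mirroring the old lru_cache(maxsize=10)
model_cache: OrderedDict = OrderedDict()

def cache_model(shipper_id, entry):
    """Store a (processor, model, device) tuple, evicting the
    least-recently-used model once the cache is full."""
    model_cache[shipper_id] = entry
    model_cache.move_to_end(shipper_id)  # mark as most recently used
    if len(model_cache) > MAX_CACHED_MODELS:
        model_cache.popitem(last=False)  # drop the oldest entry
```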
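Finally, a minimal invocation sketch of the updated entry point (the image path and shipper ID are placeholders):

```python
from PIL import Image
from routers.donut_inference import process_document_donut

image = Image.open("sample_document.png").convert("RGB")  # placeholder path
result, elapsed = process_document_donut(image, shipper_id="default_shipper")
print(result)                  # structured fields from processor.token2json(...)
print(f"took {elapsed:.2f}s")
```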