serenarolloh commited on
Commit
792e477
·
verified ·
1 Parent(s): 7894c3d

Update routers/donut_inference.py

Browse files
Files changed (1) hide show
  1. routers/donut_inference.py +25 -7
routers/donut_inference.py CHANGED
@@ -7,10 +7,11 @@ from functools import lru_cache
7
  import os
8
 
9
 
10
- @lru_cache(maxsize=1)
11
- def load_model():
12
- processor = DonutProcessor.from_pretrained(settings.processor)
13
- model = VisionEncoderDecoderModel.from_pretrained(settings.model)
 
14
 
15
  device = "cuda" if torch.cuda.is_available() else "cpu"
16
  model.to(device)
@@ -18,20 +19,37 @@ def load_model():
18
  return processor, model, device
19
 
20
 
21
- def process_document_donut(image):
22
  worker_pid = os.getpid()
23
  print(f"Handling inference request with worker PID: {worker_pid}")
24
 
25
  start_time = time.time()
26
 
27
- processor, model, device = load_model()
28
 
29
  # prepare encoder inputs
30
  pixel_values = processor(image, return_tensors="pt").pixel_values
31
 
32
  # prepare decoder inputs
33
  task_prompt = "<s_cord-v2>"
34
- decoder_input_ids = processor.tokenizer(task_prompt, add_special_tokens=False, return_tensors="pt").input_ids
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
35
 
36
  # generate answer
37
  outputs = model.generate(
 
7
  import os
8
 
9
 
10
+ @lru_cache(maxsize=10)
11
+ def load_model(shipper_id: str):
12
+ temp_settings = Settings(shipper_id=shipper_id)
13
+ processor = DonutProcessor.from_pretrained(temp_settings.processor)
14
+ model = VisionEncoderDecoderModel.from_pretrained(temp_settings.model)
15
 
16
  device = "cuda" if torch.cuda.is_available() else "cpu"
17
  model.to(device)
 
19
  return processor, model, device
20
 
21
 
22
+ def process_document_donut(image, shipper_id="default_shipper"):
23
  worker_pid = os.getpid()
24
  print(f"Handling inference request with worker PID: {worker_pid}")
25
 
26
  start_time = time.time()
27
 
28
+ processor, model, device = load_model(shipper_id)
29
 
30
  # prepare encoder inputs
31
  pixel_values = processor(image, return_tensors="pt").pixel_values
32
 
33
  # prepare decoder inputs
34
  task_prompt = "<s_cord-v2>"
35
+ decoder_input_ids = processor.tokenizer(
36
+ task_prompt,
37
+ add_special_tokens=False,
38
+ return_tensors="pt"
39
+ ).input_ids
40
+
41
+ outputs = model.generate(
42
+ pixel_values.to(device),
43
+ decoder_input_ids=decoder_input_ids.to(device),
44
+ max_length=model.decoder.config.max_position_embeddings,
45
+ early_stopping=True,
46
+ pad_token_id=processor.tokenizer.pad_token_id,
47
+ eos_token_id=processor.tokenizer.eos_token_id,
48
+ use_cache=True,
49
+ num_beams=1,
50
+ bad_words_ids=[[processor.tokenizer.unk_token_id]],
51
+ return_dict_in_generate=True,
52
+ )
53
 
54
  # generate answer
55
  outputs = model.generate(