serenarolloh committed on
Commit
447c004
·
verified ·
1 Parent(s): 4382401

Update routers/donut_inference.py

Browse files
Files changed (1) hide show
  1. routers/donut_inference.py +28 -11
routers/donut_inference.py CHANGED
@@ -5,12 +5,21 @@ from transformers import DonutProcessor, VisionEncoderDecoderModel
5
  from config import settings
6
  from functools import lru_cache
7
  import os
8
-
9
 
10
  @lru_cache(maxsize=1)
11
- def load_model():
12
- processor = DonutProcessor.from_pretrained(settings.processor)
13
- model = VisionEncoderDecoderModel.from_pretrained(settings.model)
 
 
 
 
 
 
 
 
 
14
 
15
  device = "cuda" if torch.cuda.is_available() else "cpu"
16
  model.to(device)
@@ -18,22 +27,30 @@ def load_model():
18
  return processor, model, device
19
 
20
 
21
- def process_document_donut(image):
 
 
 
 
 
 
 
22
  worker_pid = os.getpid()
23
  print(f"Handling inference request with worker PID: {worker_pid}")
24
 
25
  start_time = time.time()
26
 
27
- processor, model, device = load_model()
 
28
 
29
- # prepare encoder inputs
30
  pixel_values = processor(image, return_tensors="pt").pixel_values
31
 
32
- # prepare decoder inputs
33
  task_prompt = "<s_cord-v2>"
34
  decoder_input_ids = processor.tokenizer(task_prompt, add_special_tokens=False, return_tensors="pt").input_ids
35
 
36
- # generate answer
37
  outputs = model.generate(
38
  pixel_values.to(device),
39
  decoder_input_ids=decoder_input_ids.to(device),
@@ -47,10 +64,10 @@ def process_document_donut(image):
47
  return_dict_in_generate=True,
48
  )
49
 
50
- # postprocess
51
  sequence = processor.batch_decode(outputs.sequences)[0]
52
  sequence = sequence.replace(processor.tokenizer.eos_token, "").replace(processor.tokenizer.pad_token, "")
53
- sequence = re.sub(r"<.*?>", "", sequence, count=1).strip() # remove first task start token
54
 
55
  end_time = time.time()
56
  processing_time = end_time - start_time
 
5
  from config import settings
6
  from functools import lru_cache
7
  import os
8
+ import requests
9
 
10
  @lru_cache(maxsize=1)
11
+ def load_model(model_url: str):
12
+ """
13
+ Load the processor and model dynamically based on the model URL.
14
+
15
+ :param model_url: The URL for the model to use.
16
+ :return: The processor, model, and device.
17
+ """
18
+ # Assuming the model URL follows a pattern like "https://huggingface.co/{model_name}"
19
+ model_name = model_url.split("/")[-1] # Extract model name from the URL
20
+
21
+ processor = DonutProcessor.from_pretrained(model_name)
22
+ model = VisionEncoderDecoderModel.from_pretrained(model_name)
23
 
24
  device = "cuda" if torch.cuda.is_available() else "cpu"
25
  model.to(device)
 
27
  return processor, model, device
28
 
29
 
30
+ def process_document_donut(image, model_url: str):
31
+ """
32
+ Process the document using the DONUT model.
33
+
34
+ :param image: The input image to process.
35
+ :param model_url: The model URL to use for inference.
36
+ :return: A tuple of the result and processing time.
37
+ """
38
  worker_pid = os.getpid()
39
  print(f"Handling inference request with worker PID: {worker_pid}")
40
 
41
  start_time = time.time()
42
 
43
+ # Load the model dynamically based on the model_url
44
+ processor, model, device = load_model(model_url)
45
 
46
+ # Prepare encoder inputs
47
  pixel_values = processor(image, return_tensors="pt").pixel_values
48
 
49
+ # Prepare decoder inputs
50
  task_prompt = "<s_cord-v2>"
51
  decoder_input_ids = processor.tokenizer(task_prompt, add_special_tokens=False, return_tensors="pt").input_ids
52
 
53
+ # Generate answer
54
  outputs = model.generate(
55
  pixel_values.to(device),
56
  decoder_input_ids=decoder_input_ids.to(device),
 
64
  return_dict_in_generate=True,
65
  )
66
 
67
+ # Postprocess the result
68
  sequence = processor.batch_decode(outputs.sequences)[0]
69
  sequence = sequence.replace(processor.tokenizer.eos_token, "").replace(processor.tokenizer.pad_token, "")
70
+ sequence = re.sub(r"<.*?>", "", sequence, count=1).strip() # Remove first task start token
71
 
72
  end_time = time.time()
73
  processing_time = end_time - start_time