File size: 1,925 Bytes
c2d58b3
 
 
 
d2102d0
 
c2d58b3
 
16be0ef
d2102d0
 
 
 
c2d58b3
d2102d0
 
c2d58b3
d2102d0
c2d58b3
 
d2102d0
c2d58b3
49e3b35
c2d58b3
 
 
d2102d0
c2d58b3
d2102d0
c2d58b3
 
d2102d0
c2d58b3
 
 
d2102d0
c2d58b3
 
 
 
 
 
 
 
 
 
 
 
 
d2102d0
c2d58b3
 
d2102d0
c2d58b3
d2102d0
 
c2d58b3
49e3b35
c2d58b3
d2102d0
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
import re
import time
import torch
from transformers import DonutProcessor, VisionEncoderDecoderModel
from config import settings
from functools import lru_cache
import os


@lru_cache(maxsize=1)
def load_model():
    """Load the Donut processor and model, placing the model on the best device.

    The function takes no arguments and is wrapped in ``lru_cache(maxsize=1)``,
    so the loading work runs on the first call only; later calls in the same
    process return the very same objects.

    Returns:
        A ``(processor, model, device)`` tuple, where ``device`` is the string
        ``"cuda"`` when ``torch.cuda.is_available()`` reports a GPU, otherwise
        ``"cpu"``.
    """
    donut_processor = DonutProcessor.from_pretrained(settings.processor)
    donut_model = VisionEncoderDecoderModel.from_pretrained(settings.model)

    # Prefer the GPU when one is visible to torch; fall back to the CPU.
    if torch.cuda.is_available():
        target_device = "cuda"
    else:
        target_device = "cpu"
    donut_model.to(target_device)

    return donut_processor, donut_model, target_device


def process_document_donut(image, task_prompt="<s_cord-v2>"):
    """Run Donut document parsing on a single image and time the request.

    Args:
        image: Input handed directly to the Donut processor (presumably a PIL
            image or another type ``DonutProcessor`` accepts -- confirm against
            callers).
        task_prompt: Decoder start prompt that selects the Donut task. Defaults
            to ``"<s_cord-v2>"``, the prompt this function has always used.

    Returns:
        A tuple ``(result, processing_time)``. ``result`` is the value of
        ``processor.token2json`` applied to the cleaned decoded sequence.
        ``processing_time`` is the elapsed time in seconds for the whole
        request, including model loading on a cold first call.
    """
    worker_pid = os.getpid()
    print(f"Handling inference request with worker PID: {worker_pid}")

    # perf_counter is monotonic and high-resolution. time.time() follows the
    # system clock and can jump (NTP/manual adjustments), yielding bogus or
    # negative durations.
    start_time = time.perf_counter()

    processor, model, device = load_model()
    tokenizer = processor.tokenizer

    # prepare encoder inputs
    pixel_values = processor(image, return_tensors="pt").pixel_values

    # prepare decoder inputs: the task prompt seeds the decoder
    decoder_input_ids = tokenizer(task_prompt, add_special_tokens=False, return_tensors="pt").input_ids

    # generate answer with greedy decoding (num_beams=1). early_stopping is
    # deliberately not passed: it only affects beam search, and recent
    # transformers versions warn on every call when it is set with one beam.
    outputs = model.generate(
        pixel_values.to(device),
        decoder_input_ids=decoder_input_ids.to(device),
        max_length=model.decoder.config.max_position_embeddings,
        pad_token_id=tokenizer.pad_token_id,
        eos_token_id=tokenizer.eos_token_id,
        use_cache=True,
        num_beams=1,
        bad_words_ids=[[tokenizer.unk_token_id]],
        return_dict_in_generate=True,
    )

    # postprocess: strip EOS/PAD tokens, then drop the first <...> tag
    # (the task start token) before converting to JSON.
    sequence = processor.batch_decode(outputs.sequences)[0]
    sequence = sequence.replace(tokenizer.eos_token, "").replace(tokenizer.pad_token, "")
    sequence = re.sub(r"<.*?>", "", sequence, count=1).strip()  # remove first task start token

    processing_time = time.perf_counter() - start_time

    print(f"Inference done, worker PID: {worker_pid}")

    return processor.token2json(sequence), processing_time