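"""Donut (CORD-v2) document-parsing inference helpers.

The processor and model are loaded once per worker process and reused across requests.
"""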
import os
import re
import time
from functools import lru_cache

import torch
from transformers import DonutProcessor, VisionEncoderDecoderModel

from config import settings

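# Load the Donut processor and model once per worker process; @lru_cache keeps the
# loaded instances in memory so later requests reuse them instead of reloading weights.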
@lru_cache(maxsize=1)
def load_model():
    processor = DonutProcessor.from_pretrained(settings.processor)
    model = VisionEncoderDecoderModel.from_pretrained(settings.model)
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model.to(device)
    return processor, model, device

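# Run Donut with the CORD-v2 task prompt on a single document image and return the
# parsed JSON-like dict together with the wall-clock processing time in seconds.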
def process_document_donut(image):
    worker_pid = os.getpid()
    print(f"Handling inference request with worker PID: {worker_pid}")
    start_time = time.time()

    processor, model, device = load_model()

    # prepare encoder inputs
    pixel_values = processor(image, return_tensors="pt").pixel_values

    # prepare decoder inputs
    task_prompt = "<s_cord-v2>"
    decoder_input_ids = processor.tokenizer(task_prompt, add_special_tokens=False, return_tensors="pt").input_ids

    # generate answer
    outputs = model.generate(
        pixel_values.to(device),
        decoder_input_ids=decoder_input_ids.to(device),
        max_length=model.decoder.config.max_position_embeddings,
        early_stopping=True,
        pad_token_id=processor.tokenizer.pad_token_id,
        eos_token_id=processor.tokenizer.eos_token_id,
        use_cache=True,
        num_beams=1,
        bad_words_ids=[[processor.tokenizer.unk_token_id]],
        return_dict_in_generate=True,
    )

    # postprocess
    sequence = processor.batch_decode(outputs.sequences)[0]
    sequence = sequence.replace(processor.tokenizer.eos_token, "").replace(processor.tokenizer.pad_token, "")
    sequence = re.sub(r"<.*?>", "", sequence, count=1).strip()  # remove first task start token

    end_time = time.time()
    processing_time = end_time - start_time
    print(f"Inference done, worker PID: {worker_pid}")

    return processor.token2json(sequence), processing_time
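
# Minimal usage sketch, not part of the serving path: assumes Pillow is installed and
# that a local image named "sample_receipt.png" exists (hypothetical filename).
if __name__ == "__main__":
    from PIL import Image

    parsed, seconds = process_document_donut(Image.open("sample_receipt.png").convert("RGB"))
    print(f"Parsed in {seconds:.2f} s: {parsed}")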