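"""
Donut-based document inference utilities.

Loads a DonutProcessor / VisionEncoderDecoderModel pair per shipper (cached
in-process), runs generation on a document image, and returns the decoded
fields as JSON together with the processing time.
"""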
import logging
import os
import re
import time

import torch
from transformers import DonutProcessor, VisionEncoderDecoderModel

from config import settings, update_shipper
# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Model cache - maps shipper_id to a loaded (processor, model, device) tuple
model_cache = {}

def load_model(shipper_id: str):
    """
    Load a model for a specific shipper_id, with caching.

    Args:
        shipper_id: The shipper ID to load the model for

    Returns:
        tuple: (processor, model, device) for the specified shipper
    """
    # Check if this model is already loaded in the cache
    if shipper_id in model_cache:
        logger.info(f"Using cached model for shipper {shipper_id}")
        return model_cache[shipper_id]

    # Update settings to use the appropriate model for this shipper
    model_name, processor_name = update_shipper(shipper_id)
    logger.info(f"Loading model for shipper {shipper_id}: model={model_name}, processor={processor_name}")

    try:
        # Load the model from HuggingFace
        processor = DonutProcessor.from_pretrained(processor_name)
        model = VisionEncoderDecoderModel.from_pretrained(model_name)
        device = "cuda" if torch.cuda.is_available() else "cpu"
        model.to(device)

        # Cache the loaded model
        model_cache[shipper_id] = (processor, model, device)
        return processor, model, device
    except Exception as e:
        logger.error(f"Error loading model for shipper {shipper_id}: {str(e)}")
        # Fall back to the default model
        if 'default_shipper' in model_cache:
            logger.info("Falling back to default model")
            return model_cache['default_shipper']
        else:
            logger.info("Loading default model")
            processor = DonutProcessor.from_pretrained(settings.base_processor)
            model = VisionEncoderDecoderModel.from_pretrained(settings.base_model)
            device = "cuda" if torch.cuda.is_available() else "cpu"
            model.to(device)
            model_cache['default_shipper'] = (processor, model, device)
            return model_cache['default_shipper']

def process_document_donut(image, shipper_id="default_shipper"):
    """
    Process a document using the appropriate model for the shipper.

    Args:
        image: The document image to process
        shipper_id: Shipper ID to select a specific model

    Returns:
        tuple: (results, processing_time)
    """
    worker_pid = os.getpid()
    logger.info(f"Handling inference request with worker PID: {worker_pid}, shipper_id: {shipper_id}")
    start_time = time.time()

    # Load the model based on shipper_id
    processor, model, device = load_model(shipper_id)

    # Prepare encoder inputs
    pixel_values = processor(image, return_tensors="pt").pixel_values

    # Prepare decoder inputs
    task_prompt = "<s_cord-v2>"
    decoder_input_ids = processor.tokenizer(
        task_prompt,
        add_special_tokens=False,
        return_tensors="pt"
    ).input_ids

    # Generate the answer
    outputs = model.generate(
        pixel_values.to(device),
        decoder_input_ids=decoder_input_ids.to(device),
        max_length=model.decoder.config.max_position_embeddings,
        early_stopping=True,
        pad_token_id=processor.tokenizer.pad_token_id,
        eos_token_id=processor.tokenizer.eos_token_id,
        use_cache=True,
        num_beams=1,
        bad_words_ids=[[processor.tokenizer.unk_token_id]],
        return_dict_in_generate=True,
    )

    # Postprocess: strip special tokens, then drop the leading task start token
    sequence = processor.batch_decode(outputs.sequences)[0]
    sequence = sequence.replace(processor.tokenizer.eos_token, "").replace(processor.tokenizer.pad_token, "")
    sequence = re.sub(r"<.*?>", "", sequence, count=1).strip()  # remove first task start token

    end_time = time.time()
    processing_time = end_time - start_time
    logger.info(f"Inference done in {processing_time:.2f}s, worker PID: {worker_pid}")

    return processor.token2json(sequence), processing_time
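
# --- Illustrative usage (sketch) ---
# Hedged example of how this module might be driven from a script. The image
# path "sample_invoice.png" and the shipper id are placeholder assumptions,
# not part of the original module; any PIL document image would do.
if __name__ == "__main__":
    from PIL import Image

    sample_image = Image.open("sample_invoice.png").convert("RGB")
    results, elapsed = process_document_donut(sample_image, shipper_id="default_shipper")
    logger.info(f"Parsed fields: {results} (processing time: {elapsed:.2f}s)")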