File size: 4,173 Bytes
c2d58b3
 
 
 
0b1a8c8
d2102d0
c2d58b3
0b1a8c8
c2d58b3
0b1a8c8
 
 
16be0ef
0b1a8c8
 
c2d58b3
0b1a8c8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c2d58b3
0ca384b
0b1a8c8
 
 
 
 
 
 
 
 
 
c2d58b3
0b1a8c8
c2d58b3
0b1a8c8
 
792e477
0b1a8c8
d2102d0
c2d58b3
0b1a8c8
d2102d0
c2d58b3
792e477
0b1a8c8
 
792e477
 
0b1a8c8
d2102d0
c2d58b3
 
 
 
 
 
 
 
 
 
 
 
0b1a8c8
d2102d0
c2d58b3
 
d2102d0
0b1a8c8
d2102d0
 
0b1a8c8
 
d2102d0
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
import re
import time
import torch
from transformers import DonutProcessor, VisionEncoderDecoderModel
from config import settings, update_shipper
from functools import lru_cache
import os
import logging

# Configure INFO-level logging at import time and obtain this module's logger.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(name=__name__)

# Module-level registry of loaded models keyed by shipper_id; values are the
# (processor, model, device) tuples stored by load_model (including the
# "default_shipper" fallback entry).
model_cache = dict()

def _load_components(model_name: str, processor_name: str):
    """Load a Donut processor/model pair and move the model to the available device.

    Args:
        model_name: Name/path passed to ``VisionEncoderDecoderModel.from_pretrained``.
        processor_name: Name/path passed to ``DonutProcessor.from_pretrained``.

    Returns:
        tuple: (processor, model, device) where device is "cuda" or "cpu".
    """
    processor = DonutProcessor.from_pretrained(processor_name)
    model = VisionEncoderDecoderModel.from_pretrained(model_name)
    # Prefer the GPU whenever torch can see one.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model.to(device)
    return processor, model, device


def _load_default_components():
    """Return the fallback (processor, model, device), loading and caching it on first use."""
    if 'default_shipper' in model_cache:
        logger.info("Falling back to default model")
        return model_cache['default_shipper']

    logger.info("Loading default model")
    components = _load_components(settings.base_model, settings.base_processor)
    model_cache['default_shipper'] = components
    return components


def load_model(shipper_id: str):
    """
    Load a model for a specific shipper_id, with caching

    Args:
        shipper_id: The shipper ID to load the model for

    Returns:
        tuple: (processor, model, device) for the specified shipper. If loading
        the shipper's model fails, the default model is returned instead.
        It is cached under 'default_shipper' and built from
        settings.base_model / settings.base_processor.

    Raises:
        Whatever update_shipper raises: it is called outside the fallback
        handling, so those errors propagate to the caller. Errors while loading
        the default model also propagate.
    """
    # Check if this model is already loaded in cache
    if shipper_id in model_cache:
        logger.info("Using cached model for shipper %s", shipper_id)
        return model_cache[shipper_id]

    # Resolve the model/processor names for this shipper. Per the original
    # comment this also updates the shared settings -- confirm in config.
    model_name, processor_name = update_shipper(shipper_id)

    logger.info(
        "Loading model for shipper %s: model=%s, processor=%s",
        shipper_id, model_name, processor_name,
    )

    try:
        components = _load_components(model_name, processor_name)
    except Exception:
        # logger.exception keeps the full traceback, which str(e) alone dropped.
        logger.exception("Error loading model for shipper %s", shipper_id)
        # NOTE: the failing shipper is not cached, so every later request for it
        # retries the load before falling back.
        return _load_default_components()

    # Cache the loaded model only after a successful load.
    model_cache[shipper_id] = components
    return components

def process_document_donut(image, shipper_id="default_shipper", task_prompt="<s_cord-v2>"):
    """
    Process a document using the appropriate model for the shipper

    Args:
        image: The document image to process; passed unchanged to the Donut
            processor (presumably a PIL image or array -- confirm against callers).
        shipper_id: Shipper ID used by load_model to select a specific model
        task_prompt: Decoder start prompt. The default is the CORD-v2 task token
            used previously; override it for models fine-tuned with a
            different start token.

    Returns:
        tuple: (results, processing_time). results is the output of
        processor.token2json. processing_time is the elapsed seconds for model
        loading, generation and decoding.
    """
    worker_pid = os.getpid()
    logger.info(
        "Handling inference request with worker PID: %s, shipper_id: %s",
        worker_pid, shipper_id,
    )
    # perf_counter is monotonic, so the measured duration is immune to
    # wall-clock adjustments (unlike time.time()).
    start_time = time.perf_counter()

    # Load the model based on shipper_id
    processor, model, device = load_model(shipper_id)
    tokenizer = processor.tokenizer

    # prepare encoder inputs
    pixel_values = processor(image, return_tensors="pt").pixel_values

    # prepare decoder inputs: the task prompt seeds generation
    decoder_input_ids = tokenizer(
        task_prompt,
        add_special_tokens=False,
        return_tensors="pt",
    ).input_ids

    # generate answer
    outputs = model.generate(
        pixel_values.to(device),
        decoder_input_ids=decoder_input_ids.to(device),
        max_length=model.decoder.config.max_position_embeddings,
        early_stopping=True,
        pad_token_id=tokenizer.pad_token_id,
        eos_token_id=tokenizer.eos_token_id,
        use_cache=True,
        num_beams=1,
        bad_words_ids=[[tokenizer.unk_token_id]],
        return_dict_in_generate=True,
    )

    # postprocess: strip eos/pad tokens. Skip any the tokenizer does not
    # define, since str.replace(None, ...) would raise TypeError.
    sequence = processor.batch_decode(outputs.sequences)[0]
    for special_token in (tokenizer.eos_token, tokenizer.pad_token):
        if special_token:
            sequence = sequence.replace(special_token, "")
    sequence = re.sub(r"<.*?>", "", sequence, count=1).strip()  # remove first task start token

    processing_time = time.perf_counter() - start_time
    logger.info("Inference done in %.2fs, worker PID: %s", processing_time, worker_pid)

    return processor.token2json(sequence), processing_time