File size: 3,406 Bytes
c84055e
504d366
 
ea2e193
58a60b8
 
c2d58b3
c84055e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
504d366
ea2e193
c84055e
504d366
58a60b8
 
 
504d366
 
58a60b8
 
 
c84055e
 
504d366
 
 
ea2e193
 
 
c84055e
ea2e193
c84055e
 
 
 
 
 
504d366
 
 
 
 
 
cadf158
ea2e193
ac89508
504d366
c84055e
504d366
c84055e
504d366
 
 
 
 
 
 
 
 
 
c84055e
504d366
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
from pydantic import BaseSettings, Field
import logging

# Module-level logger for this settings module. Nothing here attaches handlers
# or sets a level, so whether these messages appear depends on the logging
# setup done elsewhere.
logger = logging.getLogger(__name__)

# shipper_id (as a string) -> model/processor identifiers to load for it.
# Shippers not listed here fall back to Settings.base_model / base_processor.
# Read-only: Settings.set_model only looks entries up and never mutates them,
# so it is built once at import time instead of on every call.
_SHIPPER_MODEL_MAP = {
    # Group 1 models - using donut-v16
    "61": {"model": "senga-ml/donut-v16", "processor": "senga-ml/donut-v16"},
    "81": {"model": "senga-ml/donut-v16", "processor": "senga-ml/donut-v16"},
    "139": {"model": "senga-ml/donut-v16", "processor": "senga-ml/donut-v16"},

    # Group 2 models - using donut-v17
    "165": {"model": "senga-ml/donut-v17", "processor": "senga-ml/donut-v17"},
    "127": {"model": "senga-ml/donut-v17", "processor": "senga-ml/donut-v17"},
    "145": {"model": "senga-ml/donut-v17", "processor": "senga-ml/donut-v17"},
}


class Settings(BaseSettings):
    """Application settings that choose the model/processor per shipper.

    ``model`` and ``processor`` are derived from ``shipper_id`` by
    :meth:`set_model`. That method runs at the end of ``__init__`` and must be
    called again after ``shipper_id`` is changed.

    NOTE(review): because ``__init__`` always calls ``set_model()``, any
    ``model``/``processor`` values passed to the constructor are overwritten.
    The same applies to values otherwise loaded by ``BaseSettings``. For a
    shipper missing from the map, they become ``base_model``/``base_processor``.
    Confirm this is intended before relying on explicit overrides.
    """

    # Field defaults. Note that ``model`` and ``processor`` are replaced by
    # set_model() during __init__, so their defaults here do not survive
    # construction.
    processor: str = Field(default="senga-ml/donut-v16")
    model: str = Field(default="senga-ml/donut-v16")
    dataset: str = Field(default="senga-ml/dnotes-data-v6")
    base_config: str = Field(default="naver-clova-ix/donut-base")
    base_processor: str = Field(default="naver-clova-ix/donut-base")
    base_model: str = Field(default="naver-clova-ix/donut-base")
    inference_stats_file: str = Field(default="data/donut_inference_stats.json")
    training_stats_file: str = Field(default="data/donut_training_stats.json")
    evaluate_stats_file: str = Field(default="data/donut_evaluate_stats.json")

    # Shipper whose entry in the shipper map selects model and processor.
    shipper_id: str = Field(default="default_shipper")

    class Config:
        # Run field validation on attribute assignment as well as at
        # construction (e.g. when shipper_id/model/processor are set later).
        # This does not reload or recompute any values by itself.
        validate_assignment = True

    def __init__(self, **kwargs):
        """Build the settings, then select model/processor for ``shipper_id``."""
        super().__init__(**kwargs)
        self.set_model()

    def set_model(self):
        """Select ``model`` and ``processor`` for the current ``shipper_id``.

        ``shipper_id`` is looked up as a string, so ``61`` and ``"61"`` pick the
        same entry. Shippers without an entry fall back to ``base_model`` and
        ``base_processor``.

        Returns:
            tuple: The ``(model, processor)`` pair now stored on the instance.
        """
        previous_model = self.model
        previous_processor = self.processor

        # Convert to str so numeric and string shipper IDs map identically.
        shipper_id_str = str(self.shipper_id)

        config = _SHIPPER_MODEL_MAP.get(
            shipper_id_str,
            {"model": self.base_model, "processor": self.base_processor},
        )

        self.model = config["model"]
        self.processor = config["processor"]

        # Lazy %-style args: messages are only formatted if INFO is enabled.
        logger.info("Shipper ID set to: %s", self.shipper_id)
        logger.info("Changed model from %s to %s", previous_model, self.model)
        logger.info(
            "Changed processor from %s to %s", previous_processor, self.processor
        )

        return self.model, self.processor

# Shared module-level Settings instance. Constructing it runs set_model() via
# Settings.__init__, so model/processor already reflect the default shipper_id
# at this point. Change shippers through update_shipper() below.
settings = Settings()
logger.info(f"Initial model setup: {settings.model}")

# Switch the shared settings object to another shipper and re-pick its model.
def update_shipper(new_shipper_id):
    """
    Point the module-level ``settings`` at a new shipper and reselect its model.

    Args:
        new_shipper_id: Shipper identifier to activate. It is assigned to
            ``settings.shipper_id``, which ``Settings`` validates on assignment.
            ``set_model`` then looks it up as a string.

    Returns:
        tuple: The ``(model, processor)`` pair returned by ``settings.set_model()``.
    """
    logger.info(f"Updating shipper ID to {new_shipper_id}")
    active = settings
    active.shipper_id = new_shipper_id
    chosen = active.set_model()
    return chosen