File size: 3,305 Bytes
c84055e
504d366
 
58a60b8
 
 
 
c2d58b3
c84055e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
504d366
 
c84055e
504d366
58a60b8
 
 
504d366
 
58a60b8
 
 
c84055e
 
504d366
 
 
c84055e
 
 
 
 
 
 
 
504d366
 
 
 
 
 
cadf158
58a60b8
ac89508
504d366
c84055e
504d366
c84055e
504d366
 
 
 
 
 
 
 
 
 
c84055e
504d366
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
from pydantic import BaseSettings, Field
import logging

# Set up logging: configure the root logger at INFO level when this module is
# imported, and create the module-level logger used by everything below.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Shipper ID -> model/processor identifiers to use for that shipper.
# Built once at import time instead of on every set_model() call.
# Keep model and processor names consistent within each entry.
_SHIPPER_MODEL_MAP = {
    # Group 1 shippers - using donut-v16
    "61": {"model": "senga-ml/donut-v16", "processor": "senga-ml/donut-v16"},
    "81": {"model": "senga-ml/donut-v16", "processor": "senga-ml/donut-v16"},
    "139": {"model": "senga-ml/donut-v16", "processor": "senga-ml/donut-v16"},

    # Group 2 shippers - using donut-v17
    "165": {"model": "senga-ml/donut-v17", "processor": "senga-ml/donut-v17"},
    "127": {"model": "senga-ml/donut-v17", "processor": "senga-ml/donut-v17"},
    "145": {"model": "senga-ml/donut-v17", "processor": "senga-ml/donut-v17"},
}


class Settings(BaseSettings):
    """Application settings with shipper-dependent model selection.

    ``model`` and ``processor`` are derived from ``shipper_id`` by
    :meth:`set_model`, which runs once at construction and can be called
    again after ``shipper_id`` is reassigned.
    """

    # Default values for the processor and model
    processor: str = Field(default="senga-ml/donut-v16")
    model: str = Field(default="senga-ml/donut-v16")
    dataset: str = Field(default="senga-ml/dnotes-data-v6")
    base_config: str = Field(default="naver-clova-ix/donut-base")
    base_processor: str = Field(default="naver-clova-ix/donut-base")
    base_model: str = Field(default="naver-clova-ix/donut-base")
    inference_stats_file: str = Field(default="data/donut_inference_stats.json")
    training_stats_file: str = Field(default="data/donut_training_stats.json")
    evaluate_stats_file: str = Field(default="data/donut_evaluate_stats.json")

    # The shipper_id to dynamically select model and processor
    shipper_id: str = Field(default="default_shipper")

    class Config:
        # Validate values on attribute assignment (e.g. settings.shipper_id = ...),
        # not only at construction. This does NOT re-run set_model() by itself.
        validate_assignment = True

    def __init__(self, **kwargs) -> None:
        """Build the settings, then resolve model/processor for ``shipper_id``.

        NOTE(review): because ``set_model()`` runs here, any ``model`` or
        ``processor`` value supplied at construction is overwritten; with a
        shipper_id that has no mapping (including the default
        ``"default_shipper"``) the effective values are ``base_model`` /
        ``base_processor``. Confirm this is intended.
        """
        super().__init__(**kwargs)
        self.set_model()

    def set_model(self) -> tuple:
        """Select ``model`` and ``processor`` for the current ``shipper_id``.

        Shipper IDs present in ``_SHIPPER_MODEL_MAP`` get their mapped pair;
        any other ID falls back to ``base_model`` / ``base_processor``.

        Returns:
            tuple: ``(model, processor)`` after the update.
        """
        previous_model = self.model
        previous_processor = self.processor

        selection = _SHIPPER_MODEL_MAP.get(self.shipper_id)
        if selection is None:
            # Unknown shipper: use the base model/processor.
            selection = {"model": self.base_model, "processor": self.base_processor}

        self.model = selection["model"]
        self.processor = selection["processor"]

        # Log the outcome for debugging; only claim a change when one happened.
        logger.info("Shipper ID set to: %s", self.shipper_id)
        for label, before, after in (
            ("model", previous_model, self.model),
            ("processor", previous_processor, self.processor),
        ):
            if before == after:
                logger.info("%s unchanged: %s", label.capitalize(), after)
            else:
                logger.info("Changed %s from %s to %s", label, before, after)

        return self.model, self.processor

# Create a singleton instance. Settings.__init__ calls set_model(), so model
# and processor are already resolved for the initial shipper_id at this point.
settings = Settings()
logger.info(f"Initial model setup: {settings.model}")

# Function to update shipper and trigger model change
def update_shipper(new_shipper_id):
    """
    Switch the shared settings to a new shipper and re-select its model.

    Args:
        new_shipper_id: Shipper ID to store on the module-level ``settings``

    Returns:
        tuple: (model, processor) chosen by ``settings.set_model()``
    """
    logger.info(f"Updating shipper ID to {new_shipper_id}")

    # Store the ID first; set_model() reads it from the settings object.
    settings.shipper_id = new_shipper_id

    selected_model, selected_processor = settings.set_model()
    return selected_model, selected_processor