# senga-dnotes / config.py
# Last change: "Update config.py" by serenarolloh (commit 461c1a0, verified); 2.46 kB.
import logging

from pydantic import BaseSettings, Field
# Hugging Face repo ids for shippers that have a dedicated fine-tuned model.
# Shippers not listed here fall back to ``Settings.base_model`` /
# ``Settings.base_processor``. Defined once at module level instead of being
# rebuilt on every ``set_model()`` call.
# NOTE(review): model ids use "donut-16"/"donut-17" while processor ids use
# "donut-v16"/"donut-v17" (and the ``model`` field default is "donut-v16");
# confirm the model repo names are not missing the "v".
_SHIPPER_MODEL_MAP = {
    "61": {"model": "senga-ml/donut-16", "processor": "senga-ml/donut-v16"},
    "81": {"model": "senga-ml/donut-16", "processor": "senga-ml/donut-v16"},
    "139": {"model": "senga-ml/donut-16", "processor": "senga-ml/donut-v16"},
    "165": {"model": "senga-ml/donut-17", "processor": "senga-ml/donut-v17"},
    "127": {"model": "senga-ml/donut-17", "processor": "senga-ml/donut-v17"},
    "145": {"model": "senga-ml/donut-17", "processor": "senga-ml/donut-v17"},
}

_logger = logging.getLogger(__name__)


class Settings(BaseSettings):
    """Configuration for the Donut processor/model, dataset and stats files.

    Field values come from pydantic ``BaseSettings`` (constructor keyword
    arguments and its other configured sources). After construction,
    ``model`` and ``processor`` are always recomputed by :meth:`set_model`
    from ``shipper_id``.
    """

    # Hugging Face repo ids. NOTE: ``model`` and ``processor`` are overwritten
    # by ``set_model()`` during ``__init__``, so these defaults (and any value
    # supplied for them at construction) do not survive construction.
    processor: str = Field(default="senga-ml/donut-v16")
    model: str = Field(default="senga-ml/donut-v16")
    dataset: str = Field(default="senga-ml/dnotes-data-v6")
    # Base repos; ``base_model``/``base_processor`` are the fallback used for
    # shippers without a dedicated entry in ``_SHIPPER_MODEL_MAP``.
    base_config: str = Field(default="naver-clova-ix/donut-base")
    base_processor: str = Field(default="naver-clova-ix/donut-base")
    base_model: str = Field(default="naver-clova-ix/donut-base")
    # JSON files holding inference / training / evaluation statistics.
    inference_stats_file: str = Field(default="data/donut_inference_stats.json")
    training_stats_file: str = Field(default="data/donut_training_stats.json")
    evaluate_stats_file: str = Field(default="data/donut_evaluate_stats.json")
    # Shipper whose model/processor should be used; looked up by set_model().
    shipper_id: str = Field(default="default_shipper")

    class Config:
        # Validate values assigned after construction against the field types.
        # This does NOT re-run set_model(): after changing ``shipper_id``,
        # call set_model() explicitly to refresh ``model``/``processor``.
        validate_assignment = True

    def __init__(self, **kwargs):
        """Build the settings from ``kwargs``, then select the shipper's model."""
        super().__init__(**kwargs)
        self.set_model()

    def set_model(self):
        """Set ``model`` and ``processor`` for the current ``shipper_id``.

        Shippers present in ``_SHIPPER_MODEL_MAP`` get their dedicated repos;
        any other id falls back to ``base_model`` / ``base_processor``.
        """
        selected = _SHIPPER_MODEL_MAP.get(self.shipper_id)
        if selected is None:
            model, processor = self.base_model, self.base_processor
        else:
            model, processor = selected["model"], selected["processor"]
        self.model = model
        self.processor = processor
        # Use logging rather than print so the selection can be silenced or
        # routed by the application's logging configuration.
        _logger.info(
            "Selected model for shipper %s: %s", self.shipper_id, self.model
        )
        _logger.info(
            "Selected processor for shipper %s: %s",
            self.shipper_id,
            self.processor,
        )
# Shared module-level settings object, used by ``update_shipper`` below.
# Constructing it runs ``Settings.set_model()`` for the default shipper_id.
# (Nothing prevents other code from building its own ``Settings()``.)
settings: Settings = Settings()
def update_shipper(new_shipper_id):
    """Point the shared ``settings`` at ``new_shipper_id`` and re-select its model.

    Assigning ``shipper_id`` by itself does not change ``model``/``processor``,
    so ``set_model()`` is invoked explicitly afterwards.

    Returns:
        A ``(model, processor)`` tuple as they stand after re-selection.
    """
    current = settings
    current.shipper_id = new_shipper_id
    current.set_model()
    selected_model = current.model
    selected_processor = current.processor
    return selected_model, selected_processor