# NOTE: page-capture residue, not part of the module: "Spaces: Sleeping".
import logging
from typing import Dict, Tuple

from pydantic import BaseSettings, Field
# Configure root logging at INFO as an import-time side effect.
# logging.basicConfig() is a no-op when the root logger already has handlers,
# so an application that configured logging earlier keeps its own setup.
logging.basicConfig(level=logging.INFO)
# Module-level logger named after this module.
logger = logging.getLogger(__name__)
# Repository names of the shipper-specific checkpoints.
_DONUT_V16 = "senga-ml/donut-v16"
_DONUT_V17 = "senga-ml/donut-v17"

# shipper_id -> repositories to load for that shipper. Built once at import
# time rather than on every set_model() call. Keys are strings because
# Settings.shipper_id is declared as str.
_SHIPPER_MODEL_MAP: Dict[str, Dict[str, str]] = {
    # Group 1 shippers use donut-v16.
    "61": {"model": _DONUT_V16, "processor": _DONUT_V16},
    "81": {"model": _DONUT_V16, "processor": _DONUT_V16},
    "139": {"model": _DONUT_V16, "processor": _DONUT_V16},
    # Group 2 shippers use donut-v17.
    "165": {"model": _DONUT_V17, "processor": _DONUT_V17},
    "127": {"model": _DONUT_V17, "processor": _DONUT_V17},
    "145": {"model": _DONUT_V17, "processor": _DONUT_V17},
}


class Settings(BaseSettings):
    """Donut model configuration with per-shipper model selection.

    ``model`` and ``processor`` are derived from ``shipper_id`` by
    :meth:`set_model`, which runs at the end of ``__init__``. Their declared
    defaults, and any values supplied for them at construction time, are
    therefore overwritten as soon as the instance has been created.
    """

    # Overwritten by set_model() during __init__ (see the class docstring).
    processor: str = Field(default="senga-ml/donut-v16")
    model: str = Field(default="senga-ml/donut-v16")
    dataset: str = Field(default="senga-ml/dnotes-data-v6")
    # base_model / base_processor are what set_model() falls back to when
    # shipper_id has no entry in the map. base_config is not referenced in
    # this module.
    base_config: str = Field(default="naver-clova-ix/donut-base")
    base_processor: str = Field(default="naver-clova-ix/donut-base")
    base_model: str = Field(default="naver-clova-ix/donut-base")
    # Stats file paths; nothing in this module reads or writes them.
    inference_stats_file: str = Field(default="data/donut_inference_stats.json")
    training_stats_file: str = Field(default="data/donut_training_stats.json")
    evaluate_stats_file: str = Field(default="data/donut_evaluate_stats.json")
    # Selects the model/processor pair; see set_model().
    shipper_id: str = Field(default="default_shipper")

    class Config:
        # Re-validate field values on attribute assignment, not only at
        # construction. This does NOT re-run set_model(): after changing
        # shipper_id, call set_model() explicitly.
        validate_assignment = True

    def __init__(self, **kwargs) -> None:
        super().__init__(**kwargs)
        # Resolve model/processor for the initial shipper_id right away.
        self.set_model()

    def set_model(self) -> Tuple[str, str]:
        """Select ``model`` and ``processor`` for the current ``shipper_id``.

        Shipper IDs without an entry in the shipper map fall back to
        ``base_model`` / ``base_processor``.

        Returns:
            The ``(model, processor)`` pair now stored on the instance.
        """
        previous_model = self.model
        previous_processor = self.processor

        config = _SHIPPER_MODEL_MAP.get(self.shipper_id)
        if config is None:
            config = {"model": self.base_model, "processor": self.base_processor}

        self.model = config["model"]
        self.processor = config["processor"]

        # Only report a change when the value actually changed; the previous
        # code logged "Changed ... from X to X" on every call.
        logger.info(f"Shipper ID set to: {self.shipper_id}")
        if self.model != previous_model:
            logger.info(f"Changed model from {previous_model} to {self.model}")
        else:
            logger.info(f"Model unchanged: {self.model}")
        if self.processor != previous_processor:
            logger.info(f"Changed processor from {previous_processor} to {self.processor}")
        else:
            logger.info(f"Processor unchanged: {self.processor}")

        return self.model, self.processor
# Shared module-level instance. Constructing it runs Settings.__init__, which
# calls set_model() for the default shipper_id, so `model` is already resolved
# by the time it is logged here.
settings = Settings()
logger.info(f"Initial model setup: {settings.model}")
def update_shipper(new_shipper_id: str) -> Tuple[str, str]:
    """Point the shared ``settings`` at a new shipper and re-select its model.

    Assigning ``shipper_id`` does not update ``model`` / ``processor`` on its
    own, so ``set_model()`` is called explicitly afterwards. This mutates the
    module-level ``settings`` instance in place.

    Args:
        new_shipper_id: The shipper ID to switch to.

    Returns:
        The ``(model, processor)`` pair returned by ``settings.set_model()``.
    """
    logger.info(f"Updating shipper ID to {new_shipper_id}")
    settings.shipper_id = new_shipper_id
    return settings.set_model()