from typing import Any, Callable, Dict, List, Optional, cast

from starfish.common.logger import get_logger
from starfish.data_factory.config import NOT_COMPLETED_THRESHOLD, TASK_RUNNER_TIMEOUT
from starfish.data_factory.constants import STORAGE_TYPE_LOCAL
from starfish.data_factory.factory_ import Factory
from starfish.data_factory.factory_wrapper import FactoryWrapper, DataFactoryProtocol, P, T
from starfish.data_factory.factory_executor_manager import FactoryExecutorManager
from starfish.data_factory.utils.data_class import FactoryMasterConfig
from starfish.data_factory.utils.state import MutableSharedState

logger = get_logger(__name__)


def data_factory(
    storage: str = STORAGE_TYPE_LOCAL,
    batch_size: int = 1,
    target_count: int = 0,
    dead_queue_threshold: int = 3,
    max_concurrency: int = 10,
    initial_state_values: Optional[Dict[str, Any]] = None,
    on_record_complete: Optional[List[Callable]] = None,
    on_record_error: Optional[List[Callable]] = None,
    show_progress: bool = True,
    task_runner_timeout: int = TASK_RUNNER_TIMEOUT,
    job_run_stop_threshold: int = NOT_COMPLETED_THRESHOLD,
) -> Callable[[Callable[P, T]], DataFactoryProtocol[P, T]]:
    """Decorator for creating data processing pipelines.

    Args:
        storage: Storage backend to use ('local' or 'in_memory')
        batch_size: Number of records to process in each batch
        target_count: Target number of records to generate (0 means process all input)
        dead_queue_threshold: Number of failed attempts allowed before a record is moved to the dead queue
        max_concurrency: Maximum number of concurrent tasks
        initial_state_values: Initial values for shared state
        on_record_complete: Callbacks to execute after successful record processing
        on_record_error: Callbacks to execute after failed record processing
        show_progress: Whether to display progress bar
        task_runner_timeout: Timeout in seconds for task execution
        job_run_stop_threshold: Threshold for stopping the job if too many records fail

    Returns:
        Decorated function with additional execution methods
    """
    # Initialize default values
    on_record_error = on_record_error or []
    on_record_complete = on_record_complete or []
    initial_state_values = initial_state_values or {}

    # Create configuration
    config = FactoryMasterConfig(
        storage=storage,
        batch_size=batch_size,
        target_count=target_count,
        dead_queue_threshold=dead_queue_threshold,
        max_concurrency=max_concurrency,
        show_progress=show_progress,
        task_runner_timeout=task_runner_timeout,
        on_record_complete=on_record_complete,
        on_record_error=on_record_error,
        job_run_stop_threshold=job_run_stop_threshold,
    )

    # Initialize factory instance
    _factory = None

    def decorator(func: Callable[P, T]) -> DataFactoryProtocol[P, T]:
        """Actual decorator that wraps the function."""
        nonlocal _factory
        _factory = _initialize_or_update_factory(_factory, config, func, initial_state_values)
        wrapper = FactoryWrapper(_factory, func)
        return cast(DataFactoryProtocol[P, T], wrapper)

    # Expose resume capability as an attribute on the decorator itself
    data_factory.resume_from_checkpoint = resume_from_checkpoint

    return decorator


def _initialize_or_update_factory(
    factory: Optional[Factory], config: FactoryMasterConfig, func: Callable[P, T], initial_state_values: Dict[str, Any]
) -> Factory:
    """Initialize or update a Factory instance."""
    if factory is None:
        factory = Factory(config, func)
        factory.state = MutableSharedState(initial_data=initial_state_values)
    else:
        factory.config = config
        factory.func = func
        factory.state = MutableSharedState(initial_data=initial_state_values)
    return factory
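# Usage sketch (kept as a comment so nothing runs at import time). A decorated
# function is wrapped in FactoryWrapper and exposed through DataFactoryProtocol;
# the `.run()` call below is an assumption based on that protocol, since the
# wrapper is defined elsewhere, and the workload function is purely illustrative.
#
#     @data_factory(storage=STORAGE_TYPE_LOCAL, batch_size=1, max_concurrency=5)
#     async def generate_city_facts(city_name: str) -> List[Dict[str, Any]]:
#         # One task per input record; return a list of generated records.
#         return [{"city": city_name, "fact": f"A fact about {city_name}."}]
#
#     # records = generate_city_facts.run(city_name=["Paris", "Tokyo", "Berlin"])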
def resume_from_checkpoint(*args, **kwargs) -> List[Dict[str, Any]]:
    """Resume a previously started data factory job from its last checkpoint.

    Args:
        master_job_id: ID of the master job to resume
        storage: Storage backend to use ('local' or 'in_memory')
        batch_size: Number of records to process in each batch
        target_count: Target number of records to generate (0 means process all input)
        max_concurrency: Maximum number of concurrent tasks
        initial_state_values: Initial values for shared state
        on_record_complete: Callbacks to execute after successful record processing
        on_record_error: Callbacks to execute after failed record processing
        show_progress: Whether to display progress bar
        task_runner_timeout: Timeout in seconds for task execution
        job_run_stop_threshold: Threshold for stopping the job if too many records fail

    Returns:
        List[Dict[str, Any]]: Records produced by the resumed job
    """
    return FactoryExecutorManager.resume(*args, **kwargs)
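# Resume sketch (comment only): `resume_from_checkpoint` is also attached to
# `data_factory` itself, so an interrupted job can be picked up by its master job
# id. Because the function simply forwards *args/**kwargs to
# FactoryExecutorManager.resume, the keyword overrides shown are an assumption
# based on the docstring above, and the job id is a placeholder.
#
#     # records = data_factory.resume_from_checkpoint(
#     #     master_job_id="<master-job-id>",
#     #     show_progress=True,
#     # )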